Follow up to #103: pull context from r.Context() when used rather than at the start of functions

This commit is contained in:
David Dworken 2023-09-07 07:56:03 -07:00
parent 86c0acfbc8
commit 68e3a813c9
No known key found for this signature in database

View File

@ -68,23 +68,23 @@ func getHishtoryVersion(r *http.Request) string {
return r.Header.Get("X-Hishtory-Version") return r.Header.Get("X-Hishtory-Version")
} }
func updateUsageData(ctx context.Context, r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) { func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) {
var usageData []UsageData var usageData []UsageData
GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData) GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
if len(usageData) == 0 { if len(usageData) == 0 {
GLOBAL_DB.WithContext(ctx).Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntriesHandled, Version: getHishtoryVersion(r)}) GLOBAL_DB.WithContext(r.Context()).Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntriesHandled, Version: getHishtoryVersion(r)})
} else { } else {
usage := usageData[0] usage := usageData[0]
GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()).Update("last_ip", getRemoteAddr(r)) GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()).Update("last_ip", getRemoteAddr(r))
if numEntriesHandled > 0 { if numEntriesHandled > 0 {
GLOBAL_DB.WithContext(ctx).Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntriesHandled, userId, deviceId) GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntriesHandled, userId, deviceId)
} }
if usage.Version != getHishtoryVersion(r) { if usage.Version != getHishtoryVersion(r) {
GLOBAL_DB.WithContext(ctx).Exec("UPDATE usage_data SET version = ? WHERE user_id = ? AND device_id = ?", getHishtoryVersion(r), userId, deviceId) GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET version = ? WHERE user_id = ? AND device_id = ?", getHishtoryVersion(r), userId, deviceId)
} }
} }
if isQuery { if isQuery {
GLOBAL_DB.WithContext(ctx).Exec("UPDATE usage_data SET num_queries = COALESCE(num_queries, 0) + 1, last_queried = ? WHERE user_id = ? AND device_id = ?", time.Now(), userId, deviceId) GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET num_queries = COALESCE(num_queries, 0) + 1, last_queried = ? WHERE user_id = ? AND device_id = ?", time.Now(), userId, deviceId)
} }
} }
@ -132,24 +132,23 @@ func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
} }
func statsHandler(w http.ResponseWriter, r *http.Request) { func statsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var numDevices int64 = 0 var numDevices int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&numDevices))
type numEntriesProcessed struct { type numEntriesProcessed struct {
Total int Total int
} }
nep := numEntriesProcessed{} nep := numEntriesProcessed{}
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Select("SUM(num_entries_handled) as total").Find(&nep)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Select("SUM(num_entries_handled) as total").Find(&nep))
var numDbEntries int64 = 0 var numDbEntries int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries))
lastWeek := time.Now().AddDate(0, 0, -7) lastWeek := time.Now().AddDate(0, 0, -7)
var weeklyActiveInstalls int64 = 0 var weeklyActiveInstalls int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Where("last_used > ?", lastWeek).Count(&weeklyActiveInstalls)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("last_used > ?", lastWeek).Count(&weeklyActiveInstalls))
var weeklyQueryUsers int64 = 0 var weeklyQueryUsers int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Where("last_queried > ?", lastWeek).Count(&weeklyQueryUsers)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("last_queried > ?", lastWeek).Count(&weeklyQueryUsers))
var lastRegistration string = "" var lastRegistration string = ""
row := GLOBAL_DB.WithContext(ctx).Raw("select to_char(max(registration_date), 'DD Month YYYY HH24:MI') from devices").Row() row := GLOBAL_DB.WithContext(r.Context()).Raw("select to_char(max(registration_date), 'DD Month YYYY HH24:MI') from devices").Row()
err := row.Scan(&lastRegistration) err := row.Scan(&lastRegistration)
if err != nil { if err != nil {
panic(err) panic(err)
@ -163,7 +162,6 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
} }
func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
panic(err) panic(err)
@ -177,15 +175,15 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
if len(entries) == 0 { if len(entries) == 0 {
return return
} }
updateUsageData(ctx, r, entries[0].UserId, entries[0].DeviceId, len(entries), false) updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false)
tx := GLOBAL_DB.WithContext(ctx).Where("user_id = ?", entries[0].UserId) tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", entries[0].UserId)
var devices []*shared.Device var devices []*shared.Device
checkGormResult(tx.Find(&devices)) checkGormResult(tx.Find(&devices))
if len(devices) == 0 { if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
} }
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
err = GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
for _, device := range devices { for _, device := range devices {
for _, entry := range entries { for _, entry := range entries {
entry.DeviceId = device.DeviceId entry.DeviceId = device.DeviceId
@ -209,11 +207,10 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
} }
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(ctx, r, userId, deviceId, 0, false) updateUsageData(r, userId, deviceId, 0, false)
tx := GLOBAL_DB.WithContext(ctx).Where("user_id = ?", userId) tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry var historyEntries []*shared.EncHistoryEntry
checkGormResult(tx.Find(&historyEntries)) checkGormResult(tx.Find(&historyEntries))
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
@ -228,20 +225,20 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(ctx, r, userId, deviceId, 0, true) updateUsageData(r, userId, deviceId, 0, true)
// Delete any entries that match a pending deletion request // Delete any entries that match a pending deletion request
var deletionRequests []*shared.DeletionRequest var deletionRequests []*shared.DeletionRequest
checkGormResult(GLOBAL_DB.WithContext(ctx).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
for _, request := range deletionRequests { for _, request := range deletionRequests {
_, err := applyDeletionRequestsToBackend(ctx, *request) _, err := applyDeletionRequestsToBackend(r.Context(), *request)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
// Then retrieve // Then retrieve
tx := GLOBAL_DB.WithContext(ctx).Where("device_id = ? AND read_count < 5", deviceId) tx := GLOBAL_DB.WithContext(r.Context()).Where("device_id = ? AND read_count < 5", deviceId)
var historyEntries []*shared.EncHistoryEntry var historyEntries []*shared.EncHistoryEntry
checkGormResult(tx.Find(&historyEntries)) checkGormResult(tx.Find(&historyEntries))
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
@ -284,9 +281,8 @@ func getRemoteAddr(r *http.Request) string {
} }
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if getMaximumNumberOfAllowedUsers() < math.MaxInt { if getMaximumNumberOfAllowedUsers() < math.MaxInt {
row := GLOBAL_DB.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row() row := GLOBAL_DB.WithContext(r.Context()).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row()
var numDistinctUsers int64 = 0 var numDistinctUsers int64 = 0
err := row.Scan(&numDistinctUsers) err := row.Scan(&numDistinctUsers)
if err != nil { if err != nil {
@ -299,13 +295,13 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
var existingDevicesCount int64 = -1 var existingDevicesCount int64 = -1
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount))
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()})) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}))
if existingDevicesCount > 0 { if existingDevicesCount > 0 {
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}))
} }
updateUsageData(ctx, r, userId, deviceId, 0, false) updateUsageData(r, userId, deviceId, 0, false)
if GLOBAL_STATSD != nil { if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0) GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0)
@ -329,7 +325,6 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
} }
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
srcDeviceId := getRequiredQueryParam(r, "source_device_id") srcDeviceId := getRequiredQueryParam(r, "source_device_id")
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
@ -343,7 +338,7 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
panic(fmt.Sprintf("body=%#v, err=%v", data, err)) panic(fmt.Sprintf("body=%#v, err=%v", data, err))
} }
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
err = GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
for _, entry := range entries { for _, entry := range entries {
entry.DeviceId = requestingDeviceId entry.DeviceId = requestingDeviceId
if entry.UserId != userId { if entry.UserId != userId {
@ -356,8 +351,8 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
panic(fmt.Errorf("failed to execute transaction to add dumped DB: %w", err)) panic(fmt.Errorf("failed to execute transaction to add dumped DB: %w", err))
} }
checkGormResult(GLOBAL_DB.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
updateUsageData(ctx, r, userId, srcDeviceId, len(entries), false) updateUsageData(r, userId, srcDeviceId, len(entries), false)
w.Header().Set("Content-Length", "0") w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -376,16 +371,15 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
} }
func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount // Increment the ReadCount
checkGormResult(GLOBAL_DB.WithContext(ctx).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
// Return all the deletion requests // Return all the deletion requests
var deletionRequests []*shared.DeletionRequest var deletionRequests []*shared.DeletionRequest
checkGormResult(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
respBody, err := json.Marshal(deletionRequests) respBody, err := json.Marshal(deletionRequests)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
@ -394,7 +388,6 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
} }
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
panic(err) panic(err)
@ -408,7 +401,7 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
// Store the deletion request so all the devices will get it // Store the deletion request so all the devices will get it
tx := GLOBAL_DB.WithContext(ctx).Where("user_id = ?", request.UserId) tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", request.UserId)
var devices []*shared.Device var devices []*shared.Device
checkGormResult(tx.Find(&devices)) checkGormResult(tx.Find(&devices))
if len(devices) == 0 { if len(devices) == 0 {
@ -417,11 +410,11 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices)) fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
for _, device := range devices { for _, device := range devices {
request.DestinationDeviceId = device.DeviceId request.DestinationDeviceId = device.DeviceId
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&request)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&request))
} }
// Also delete anything currently in the DB matching it // Also delete anything currently in the DB matching it
numDeleted, err := applyDeletionRequestsToBackend(ctx, request) numDeleted, err := applyDeletionRequestsToBackend(r.Context(), request)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -432,7 +425,6 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
} }
func healthCheckHandler(w http.ResponseWriter, r *http.Request) { func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if isProductionEnvironment() { if isProductionEnvironment() {
// Check that we have a reasonable looking set of devices/entries in the DB // Check that we have a reasonable looking set of devices/entries in the DB
rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
@ -444,12 +436,12 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
panic("Suspiciously few enc history entries!") panic("Suspiciously few enc history entries!")
} }
var count int64 var count int64
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&count)) checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&count))
if count < 100 { if count < 100 {
panic("Suspiciously few devices!") panic("Suspiciously few devices!")
} }
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&shared.EncHistoryEntry{ checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.EncHistoryEntry{
EncryptedData: []byte("data"), EncryptedData: []byte("data"),
Nonce: []byte("nonce"), Nonce: []byte("nonce"),
DeviceId: "healthcheck_device_id", DeviceId: "healthcheck_device_id",