Follow up to #103: pull context from r.Context() when used rather than at the start of functions

This commit is contained in:
David Dworken 2023-09-07 07:56:03 -07:00
parent 86c0acfbc8
commit 68e3a813c9
No known key found for this signature in database

View File

@ -68,23 +68,23 @@ func getHishtoryVersion(r *http.Request) string {
return r.Header.Get("X-Hishtory-Version")
}
func updateUsageData(ctx context.Context, r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) {
func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) {
var usageData []UsageData
GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
if len(usageData) == 0 {
GLOBAL_DB.WithContext(ctx).Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntriesHandled, Version: getHishtoryVersion(r)})
GLOBAL_DB.WithContext(r.Context()).Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntriesHandled, Version: getHishtoryVersion(r)})
} else {
usage := usageData[0]
GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()).Update("last_ip", getRemoteAddr(r))
GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()).Update("last_ip", getRemoteAddr(r))
if numEntriesHandled > 0 {
GLOBAL_DB.WithContext(ctx).Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntriesHandled, userId, deviceId)
GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntriesHandled, userId, deviceId)
}
if usage.Version != getHishtoryVersion(r) {
GLOBAL_DB.WithContext(ctx).Exec("UPDATE usage_data SET version = ? WHERE user_id = ? AND device_id = ?", getHishtoryVersion(r), userId, deviceId)
GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET version = ? WHERE user_id = ? AND device_id = ?", getHishtoryVersion(r), userId, deviceId)
}
}
if isQuery {
GLOBAL_DB.WithContext(ctx).Exec("UPDATE usage_data SET num_queries = COALESCE(num_queries, 0) + 1, last_queried = ? WHERE user_id = ? AND device_id = ?", time.Now(), userId, deviceId)
GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET num_queries = COALESCE(num_queries, 0) + 1, last_queried = ? WHERE user_id = ? AND device_id = ?", time.Now(), userId, deviceId)
}
}
@ -132,24 +132,23 @@ func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
}
func statsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var numDevices int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&numDevices))
type numEntriesProcessed struct {
Total int
}
nep := numEntriesProcessed{}
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Select("SUM(num_entries_handled) as total").Find(&nep))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Select("SUM(num_entries_handled) as total").Find(&nep))
var numDbEntries int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries))
lastWeek := time.Now().AddDate(0, 0, -7)
var weeklyActiveInstalls int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Where("last_used > ?", lastWeek).Count(&weeklyActiveInstalls))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("last_used > ?", lastWeek).Count(&weeklyActiveInstalls))
var weeklyQueryUsers int64 = 0
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&UsageData{}).Where("last_queried > ?", lastWeek).Count(&weeklyQueryUsers))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("last_queried > ?", lastWeek).Count(&weeklyQueryUsers))
var lastRegistration string = ""
row := GLOBAL_DB.WithContext(ctx).Raw("select to_char(max(registration_date), 'DD Month YYYY HH24:MI') from devices").Row()
row := GLOBAL_DB.WithContext(r.Context()).Raw("select to_char(max(registration_date), 'DD Month YYYY HH24:MI') from devices").Row()
err := row.Scan(&lastRegistration)
if err != nil {
panic(err)
@ -163,7 +162,6 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
}
func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
@ -177,15 +175,15 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
if len(entries) == 0 {
return
}
updateUsageData(ctx, r, entries[0].UserId, entries[0].DeviceId, len(entries), false)
tx := GLOBAL_DB.WithContext(ctx).Where("user_id = ?", entries[0].UserId)
updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false)
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", entries[0].UserId)
var devices []*shared.Device
checkGormResult(tx.Find(&devices))
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
}
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
err = GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
for _, device := range devices {
for _, entry := range entries {
entry.DeviceId = device.DeviceId
@ -209,11 +207,10 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
}
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(ctx, r, userId, deviceId, 0, false)
tx := GLOBAL_DB.WithContext(ctx).Where("user_id = ?", userId)
updateUsageData(r, userId, deviceId, 0, false)
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry
checkGormResult(tx.Find(&historyEntries))
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
@ -228,20 +225,20 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(ctx, r, userId, deviceId, 0, true)
updateUsageData(r, userId, deviceId, 0, true)
// Delete any entries that match a pending deletion request
var deletionRequests []*shared.DeletionRequest
checkGormResult(GLOBAL_DB.WithContext(ctx).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
for _, request := range deletionRequests {
_, err := applyDeletionRequestsToBackend(ctx, *request)
_, err := applyDeletionRequestsToBackend(r.Context(), *request)
if err != nil {
panic(err)
}
}
// Then retrieve
tx := GLOBAL_DB.WithContext(ctx).Where("device_id = ? AND read_count < 5", deviceId)
tx := GLOBAL_DB.WithContext(r.Context()).Where("device_id = ? AND read_count < 5", deviceId)
var historyEntries []*shared.EncHistoryEntry
checkGormResult(tx.Find(&historyEntries))
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
@ -284,9 +281,8 @@ func getRemoteAddr(r *http.Request) string {
}
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if getMaximumNumberOfAllowedUsers() < math.MaxInt {
row := GLOBAL_DB.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row()
row := GLOBAL_DB.WithContext(r.Context()).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row()
var numDistinctUsers int64 = 0
err := row.Scan(&numDistinctUsers)
if err != nil {
@ -299,13 +295,13 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var existingDevicesCount int64 = -1
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount))
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}))
if existingDevicesCount > 0 {
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}))
}
updateUsageData(ctx, r, userId, deviceId, 0, false)
updateUsageData(r, userId, deviceId, 0, false)
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0)
@ -329,7 +325,6 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
}
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id")
srcDeviceId := getRequiredQueryParam(r, "source_device_id")
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
@ -343,7 +338,7 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
err = GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
for _, entry := range entries {
entry.DeviceId = requestingDeviceId
if entry.UserId != userId {
@ -356,8 +351,8 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add dumped DB: %w", err))
}
checkGormResult(GLOBAL_DB.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
updateUsageData(ctx, r, userId, srcDeviceId, len(entries), false)
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
updateUsageData(r, userId, srcDeviceId, len(entries), false)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
@ -376,16 +371,15 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
}
func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount
checkGormResult(GLOBAL_DB.WithContext(ctx).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
// Return all the deletion requests
var deletionRequests []*shared.DeletionRequest
checkGormResult(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
respBody, err := json.Marshal(deletionRequests)
if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
@ -394,7 +388,6 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
}
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
@ -408,7 +401,7 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
// Store the deletion request so all the devices will get it
tx := GLOBAL_DB.WithContext(ctx).Where("user_id = ?", request.UserId)
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", request.UserId)
var devices []*shared.Device
checkGormResult(tx.Find(&devices))
if len(devices) == 0 {
@ -417,11 +410,11 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
for _, device := range devices {
request.DestinationDeviceId = device.DeviceId
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&request))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&request))
}
// Also delete anything currently in the DB matching it
numDeleted, err := applyDeletionRequestsToBackend(ctx, request)
numDeleted, err := applyDeletionRequestsToBackend(r.Context(), request)
if err != nil {
panic(err)
}
@ -432,7 +425,6 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
}
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if isProductionEnvironment() {
// Check that we have a reasonable looking set of devices/entries in the DB
rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
@ -444,12 +436,12 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
panic("Suspiciously few enc history entries!")
}
var count int64
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&count))
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&count))
if count < 100 {
panic("Suspiciously few devices!")
}
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(&shared.EncHistoryEntry{
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.EncHistoryEntry{
EncryptedData: []byte("data"),
Nonce: []byte("nonce"),
DeviceId: "healthcheck_device_id",