Always check gorm interactions for errors

This commit is contained in:
David Dworken 2022-10-02 19:41:00 -07:00
parent 77e078489e
commit b7c64b61c8

View File

@ -11,6 +11,7 @@ import (
"net/http"
"os"
"os/user"
"runtime"
"strconv"
"strings"
"time"
@ -45,8 +46,6 @@ type UsageData struct {
NumQueries int `json:"num_queries"`
}
// TODO: Audit this file for queries that don't check result.Error
func getRequiredQueryParam(r *http.Request, queryParam string) string {
val := r.URL.Query().Get(queryParam)
if val == "" {
@ -124,20 +123,14 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
updateUsageData(r, entry.UserId, entry.DeviceId, 1, false)
tx := GLOBAL_DB.Where("user_id = ?", entry.UserId)
var devices []*shared.Device
result := tx.Find(&devices)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
checkGormResult(tx.Find(&devices))
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entry.UserId))
}
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
for _, device := range devices {
entry.DeviceId = device.DeviceId
result := GLOBAL_DB.Create(&entry)
if result.Error != nil {
panic(result.Error)
}
checkGormResult(GLOBAL_DB.Create(&entry))
}
}
}
@ -148,10 +141,7 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
updateUsageData(r, userId, deviceId, 0, false)
tx := GLOBAL_DB.Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
checkGormResult(tx.Find(&historyEntries))
resp, err := json.Marshal(historyEntries)
if err != nil {
panic(err)
@ -164,14 +154,11 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(r, userId, deviceId, 0, true)
// Increment the count
result := GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId)
if result.Error != nil {
panic(result.Error)
}
checkGormResult(GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId))
// Delete any entries that match a pending deletion request
var deletionRequests []*shared.DeletionRequest
GLOBAL_DB.Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests)
checkGormResult(GLOBAL_DB.Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
for _, request := range deletionRequests {
_, err := applyDeletionRequestsToBackend(*request)
if err != nil {
@ -182,10 +169,7 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
// Then retrieve, to avoid a race condition
tx := GLOBAL_DB.Where("device_id = ? AND read_count < 5", deviceId)
var historyEntries []*shared.EncHistoryEntry
result = tx.Find(&historyEntries)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
checkGormResult(tx.Find(&historyEntries))
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
resp, err := json.Marshal(historyEntries)
if err != nil {
@ -206,14 +190,11 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var existingDevicesCount int64 = -1
result := GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount)
checkGormResult(GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount))
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
if result.Error != nil {
panic(result.Error)
}
GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()})
checkGormResult(GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}))
if existingDevicesCount > 0 {
GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})
checkGormResult(GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}))
}
updateUsageData(r, userId, deviceId, 0, false)
}
@ -223,10 +204,7 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
result := GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
checkGormResult(GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
respBody, err := json.Marshal(dumpRequests)
if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
@ -254,20 +232,14 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
if entry.UserId != userId {
return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)
}
result := tx.Create(&entry)
if result.Error != nil {
return fmt.Errorf("failed to create entry: %v", err)
}
checkGormResult(tx.Create(&entry))
}
return nil
})
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add dumped DB: %v", err))
}
result := GLOBAL_DB.Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId)
if result.Error != nil {
panic(fmt.Errorf("failed to clear the dump request: %v", err))
}
checkGormResult(GLOBAL_DB.Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
updateUsageData(r, userId, srcDeviceId, len(entries), false)
}
@ -284,17 +256,11 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount
result := GLOBAL_DB.Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId)
if result.Error != nil {
panic(result.Error)
}
checkGormResult(GLOBAL_DB.Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
// Return all the deletion requests
var deletionRequests []*shared.DeletionRequest
result = GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
checkGormResult(GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
respBody, err := json.Marshal(deletionRequests)
if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
@ -318,20 +284,14 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
// Store the deletion request so all the devices will get it
tx := GLOBAL_DB.Where("user_id = ?", request.UserId)
var devices []*shared.Device
result := tx.Find(&devices)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
checkGormResult(tx.Find(&devices))
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId))
}
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
for _, device := range devices {
request.DestinationDeviceId = device.DeviceId
result := GLOBAL_DB.Create(&request)
if result.Error != nil {
panic(result.Error)
}
checkGormResult(GLOBAL_DB.Create(&request))
}
// Also delete anything currently in the DB matching it
@ -348,17 +308,12 @@ func applyDeletionRequestsToBackend(request shared.DeletionRequest) (int, error)
tx = tx.Or(GLOBAL_DB.Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date))
}
result := tx.Delete(&shared.EncHistoryEntry{})
if result.Error != nil {
return 0, result.Error
}
checkGormResult(result)
return int(result.RowsAffected), nil
}
func wipeDbHandler(w http.ResponseWriter, r *http.Request) {
result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries")
if result.Error != nil {
panic(result.Error)
}
checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries"))
}
func isTestEnvironment() bool {
@ -604,14 +559,8 @@ func byteCountToString(b int) string {
}
func cleanDatabase() error {
result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries WHERE read_count > 10")
if result.Error != nil {
return result.Error
}
result = GLOBAL_DB.Exec("DELETE FROM deletion_requests WHERE read_count > 100")
if result.Error != nil {
return result.Error
}
checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries WHERE read_count > 10"))
checkGormResult(GLOBAL_DB.Exec("DELETE FROM deletion_requests WHERE read_count > 100"))
// TODO(optimization): Clean the database by deleting entries for users that haven't been used in X amount of time
return nil
}
@ -659,4 +608,11 @@ func basicAuth(next http.HandlerFunc) http.HandlerFunc {
})
}
func checkGormResult(result *gorm.DB) {
if result.Error != nil {
_, filename, line, _ := runtime.Caller(1)
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, result.Error))
}
}
// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?