Always check gorm interactions for errors

This commit is contained in:
David Dworken 2022-10-02 19:41:00 -07:00
parent 77e078489e
commit b7c64b61c8

View File

@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/user" "os/user"
"runtime"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -45,8 +46,6 @@ type UsageData struct {
NumQueries int `json:"num_queries"` NumQueries int `json:"num_queries"`
} }
// TODO: Audit this file for queries that don't check result.Error
func getRequiredQueryParam(r *http.Request, queryParam string) string { func getRequiredQueryParam(r *http.Request, queryParam string) string {
val := r.URL.Query().Get(queryParam) val := r.URL.Query().Get(queryParam)
if val == "" { if val == "" {
@ -124,20 +123,14 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
updateUsageData(r, entry.UserId, entry.DeviceId, 1, false) updateUsageData(r, entry.UserId, entry.DeviceId, 1, false)
tx := GLOBAL_DB.Where("user_id = ?", entry.UserId) tx := GLOBAL_DB.Where("user_id = ?", entry.UserId)
var devices []*shared.Device var devices []*shared.Device
result := tx.Find(&devices) checkGormResult(tx.Find(&devices))
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
if len(devices) == 0 { if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entry.UserId)) panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entry.UserId))
} }
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
for _, device := range devices { for _, device := range devices {
entry.DeviceId = device.DeviceId entry.DeviceId = device.DeviceId
result := GLOBAL_DB.Create(&entry) checkGormResult(GLOBAL_DB.Create(&entry))
if result.Error != nil {
panic(result.Error)
}
} }
} }
} }
@ -148,10 +141,7 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
updateUsageData(r, userId, deviceId, 0, false) updateUsageData(r, userId, deviceId, 0, false)
tx := GLOBAL_DB.Where("user_id = ?", userId) tx := GLOBAL_DB.Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry var historyEntries []*shared.EncHistoryEntry
result := tx.Find(&historyEntries) checkGormResult(tx.Find(&historyEntries))
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
resp, err := json.Marshal(historyEntries) resp, err := json.Marshal(historyEntries)
if err != nil { if err != nil {
panic(err) panic(err)
@ -164,14 +154,11 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(r, userId, deviceId, 0, true) updateUsageData(r, userId, deviceId, 0, true)
// Increment the count // Increment the count
result := GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId) checkGormResult(GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId))
if result.Error != nil {
panic(result.Error)
}
// Delete any entries that match a pending deletion request // Delete any entries that match a pending deletion request
var deletionRequests []*shared.DeletionRequest var deletionRequests []*shared.DeletionRequest
GLOBAL_DB.Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests) checkGormResult(GLOBAL_DB.Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
for _, request := range deletionRequests { for _, request := range deletionRequests {
_, err := applyDeletionRequestsToBackend(*request) _, err := applyDeletionRequestsToBackend(*request)
if err != nil { if err != nil {
@ -182,10 +169,7 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
// Then retrieve, to avoid a race condition // Then retrieve, to avoid a race condition
tx := GLOBAL_DB.Where("device_id = ? AND read_count < 5", deviceId) tx := GLOBAL_DB.Where("device_id = ? AND read_count < 5", deviceId)
var historyEntries []*shared.EncHistoryEntry var historyEntries []*shared.EncHistoryEntry
result = tx.Find(&historyEntries) checkGormResult(tx.Find(&historyEntries))
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
resp, err := json.Marshal(historyEntries) resp, err := json.Marshal(historyEntries)
if err != nil { if err != nil {
@ -206,14 +190,11 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id") userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
var existingDevicesCount int64 = -1 var existingDevicesCount int64 = -1
result := GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount) checkGormResult(GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount))
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
if result.Error != nil { checkGormResult(GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}))
panic(result.Error)
}
GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()})
if existingDevicesCount > 0 { if existingDevicesCount > 0 {
GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) checkGormResult(GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}))
} }
updateUsageData(r, userId, deviceId, 0, false) updateUsageData(r, userId, deviceId, 0, false)
} }
@ -223,10 +204,7 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request // Filter out ones requested by the hishtory instance that sent this request
result := GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests) checkGormResult(GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
respBody, err := json.Marshal(dumpRequests) respBody, err := json.Marshal(dumpRequests)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
@ -254,20 +232,14 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
if entry.UserId != userId { if entry.UserId != userId {
return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId) return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)
} }
result := tx.Create(&entry) checkGormResult(tx.Create(&entry))
if result.Error != nil {
return fmt.Errorf("failed to create entry: %v", err)
}
} }
return nil return nil
}) })
if err != nil { if err != nil {
panic(fmt.Errorf("failed to execute transaction to add dumped DB: %v", err)) panic(fmt.Errorf("failed to execute transaction to add dumped DB: %v", err))
} }
result := GLOBAL_DB.Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId) checkGormResult(GLOBAL_DB.Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
if result.Error != nil {
panic(fmt.Errorf("failed to clear the dump request: %v", err))
}
updateUsageData(r, userId, srcDeviceId, len(entries), false) updateUsageData(r, userId, srcDeviceId, len(entries), false)
} }
@ -284,17 +256,11 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id") deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount // Increment the ReadCount
result := GLOBAL_DB.Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId) checkGormResult(GLOBAL_DB.Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
if result.Error != nil {
panic(result.Error)
}
// Return all the deletion requests // Return all the deletion requests
var deletionRequests []*shared.DeletionRequest var deletionRequests []*shared.DeletionRequest
result = GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests) checkGormResult(GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
respBody, err := json.Marshal(deletionRequests) respBody, err := json.Marshal(deletionRequests)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
@ -318,20 +284,14 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
// Store the deletion request so all the devices will get it // Store the deletion request so all the devices will get it
tx := GLOBAL_DB.Where("user_id = ?", request.UserId) tx := GLOBAL_DB.Where("user_id = ?", request.UserId)
var devices []*shared.Device var devices []*shared.Device
result := tx.Find(&devices) checkGormResult(tx.Find(&devices))
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
if len(devices) == 0 { if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId)) panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId))
} }
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices)) fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
for _, device := range devices { for _, device := range devices {
request.DestinationDeviceId = device.DeviceId request.DestinationDeviceId = device.DeviceId
result := GLOBAL_DB.Create(&request) checkGormResult(GLOBAL_DB.Create(&request))
if result.Error != nil {
panic(result.Error)
}
} }
// Also delete anything currently in the DB matching it // Also delete anything currently in the DB matching it
@ -348,17 +308,12 @@ func applyDeletionRequestsToBackend(request shared.DeletionRequest) (int, error)
tx = tx.Or(GLOBAL_DB.Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date)) tx = tx.Or(GLOBAL_DB.Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date))
} }
result := tx.Delete(&shared.EncHistoryEntry{}) result := tx.Delete(&shared.EncHistoryEntry{})
if result.Error != nil { checkGormResult(result)
return 0, result.Error
}
return int(result.RowsAffected), nil return int(result.RowsAffected), nil
} }
func wipeDbHandler(w http.ResponseWriter, r *http.Request) { func wipeDbHandler(w http.ResponseWriter, r *http.Request) {
result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries") checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries"))
if result.Error != nil {
panic(result.Error)
}
} }
func isTestEnvironment() bool { func isTestEnvironment() bool {
@ -604,14 +559,8 @@ func byteCountToString(b int) string {
} }
func cleanDatabase() error { func cleanDatabase() error {
result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries WHERE read_count > 10") checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries WHERE read_count > 10"))
if result.Error != nil { checkGormResult(GLOBAL_DB.Exec("DELETE FROM deletion_requests WHERE read_count > 100"))
return result.Error
}
result = GLOBAL_DB.Exec("DELETE FROM deletion_requests WHERE read_count > 100")
if result.Error != nil {
return result.Error
}
// TODO(optimization): Clean the database by deleting entries for users that haven't been used in X amount of time // TODO(optimization): Clean the database by deleting entries for users that haven't been used in X amount of time
return nil return nil
} }
@ -659,4 +608,11 @@ func basicAuth(next http.HandlerFunc) http.HandlerFunc {
}) })
} }
func checkGormResult(result *gorm.DB) {
if result.Error != nil {
_, filename, line, _ := runtime.Caller(1)
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, result.Error))
}
}
// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required? // TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?