diff --git a/backend/server/server.go b/backend/server/server.go index cf7224c..7dc8454 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "os/user" + "runtime" "strconv" "strings" "time" @@ -45,8 +46,6 @@ type UsageData struct { NumQueries int `json:"num_queries"` } -// TODO: Audit this file for queries that don't check result.Error - func getRequiredQueryParam(r *http.Request, queryParam string) string { val := r.URL.Query().Get(queryParam) if val == "" { @@ -124,20 +123,14 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { updateUsageData(r, entry.UserId, entry.DeviceId, 1, false) tx := GLOBAL_DB.Where("user_id = ?", entry.UserId) var devices []*shared.Device - result := tx.Find(&devices) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } + checkGormResult(tx.Find(&devices)) if len(devices) == 0 { panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entry.UserId)) } fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) for _, device := range devices { entry.DeviceId = device.DeviceId - result := GLOBAL_DB.Create(&entry) - if result.Error != nil { - panic(result.Error) - } + checkGormResult(GLOBAL_DB.Create(&entry)) } } } @@ -148,10 +141,7 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { updateUsageData(r, userId, deviceId, 0, false) tx := GLOBAL_DB.Where("user_id = ?", userId) var historyEntries []*shared.EncHistoryEntry - result := tx.Find(&historyEntries) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } + checkGormResult(tx.Find(&historyEntries)) resp, err := json.Marshal(historyEntries) if err != nil { panic(err) @@ -164,14 +154,11 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) { deviceId := getRequiredQueryParam(r, "device_id") updateUsageData(r, userId, deviceId, 0, true) // Increment the count - result := GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId) - if result.Error != nil { - panic(result.Error) - } + checkGormResult(GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId)) // Delete any entries that match a pending deletion request var deletionRequests []*shared.DeletionRequest - GLOBAL_DB.Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests) + checkGormResult(GLOBAL_DB.Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests)) for _, request := range deletionRequests { _, err := applyDeletionRequestsToBackend(*request) if err != nil { @@ -182,10 +169,7 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) { // Then retrieve, to avoid a race condition tx := GLOBAL_DB.Where("device_id = ? AND read_count < 5", deviceId) var historyEntries []*shared.EncHistoryEntry - result = tx.Find(&historyEntries) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } + checkGormResult(tx.Find(&historyEntries)) fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) resp, err := json.Marshal(historyEntries) if err != nil { @@ -206,14 +190,11 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") var existingDevicesCount int64 = -1 - result := GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount) + checkGormResult(GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount)) fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) - if result.Error != nil { - panic(result.Error) - } - GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}) + checkGormResult(GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()})) if existingDevicesCount > 0 { - GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) + checkGormResult(GLOBAL_DB.Create(&shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})) } updateUsageData(r, userId, deviceId, 0, false) } @@ -223,10 +204,7 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { deviceId := getRequiredQueryParam(r, "device_id") var dumpRequests []*shared.DumpRequest // Filter out ones requested by the hishtory instance that sent this request - result := GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } + checkGormResult(GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)) respBody, err := json.Marshal(dumpRequests) if err != nil { panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) @@ -254,20 +232,14 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { if entry.UserId != userId { return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId) } - result := tx.Create(&entry) - if result.Error != nil { - return fmt.Errorf("failed to create entry: %v", err) - } + checkGormResult(tx.Create(&entry)) } return nil }) if err != nil { panic(fmt.Errorf("failed to execute transaction to add dumped DB: %v", err)) } - result := GLOBAL_DB.Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId) - if result.Error != nil { - panic(fmt.Errorf("failed to clear the dump request: %v", err)) - } + checkGormResult(GLOBAL_DB.Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId)) updateUsageData(r, userId, srcDeviceId, len(entries), false) } @@ -284,17 +256,11 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { deviceId := getRequiredQueryParam(r, "device_id") // Increment the ReadCount - result := GLOBAL_DB.Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId) - if result.Error != nil { - panic(result.Error) - } + checkGormResult(GLOBAL_DB.Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId)) // Return all the deletion requests var deletionRequests []*shared.DeletionRequest - result = GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } + checkGormResult(GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests)) respBody, err := json.Marshal(deletionRequests) if err != nil { panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) @@ -318,20 +284,14 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { // Store the deletion request so all the devices will get it tx := GLOBAL_DB.Where("user_id = ?", request.UserId) var devices []*shared.Device - result := tx.Find(&devices) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } + checkGormResult(tx.Find(&devices)) if len(devices) == 0 { panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId)) } fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices)) for _, device := range devices { request.DestinationDeviceId = device.DeviceId - result := GLOBAL_DB.Create(&request) - if result.Error != nil { - panic(result.Error) - } + checkGormResult(GLOBAL_DB.Create(&request)) } // Also delete anything currently in the DB matching it @@ -348,17 +308,12 @@ func applyDeletionRequestsToBackend(request shared.DeletionRequest) (int, error) tx = tx.Or(GLOBAL_DB.Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date)) } result := tx.Delete(&shared.EncHistoryEntry{}) - if result.Error != nil { - return 0, result.Error - } + checkGormResult(result) return int(result.RowsAffected), nil } func wipeDbHandler(w http.ResponseWriter, r *http.Request) { - result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries") - if result.Error != nil { - panic(result.Error) - } + checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) } func isTestEnvironment() bool { @@ -604,14 +559,8 @@ func byteCountToString(b int) string { } func cleanDatabase() error { - result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries WHERE read_count > 10") - if result.Error != nil { - return result.Error - } - result = GLOBAL_DB.Exec("DELETE FROM deletion_requests WHERE read_count > 100") - if result.Error != nil { - return result.Error - } + checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries WHERE read_count > 10")) + checkGormResult(GLOBAL_DB.Exec("DELETE FROM deletion_requests WHERE read_count > 100")) // TODO(optimization): Clean the database by deleting entries for users that haven't been used in X amount of time return nil } @@ -659,4 +608,11 @@ func basicAuth(next http.HandlerFunc) http.HandlerFunc { }) } +func checkGormResult(result *gorm.DB) { + if result.Error != nil { + _, filename, line, _ := runtime.Caller(1) + panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, result.Error)) + } +} + // TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?