Make query params required rather than having weird undefined behavior

This commit is contained in:
David Dworken
2022-06-04 23:03:05 -07:00
parent 84182ba5c3
commit 0fac3b7286
3 changed files with 47 additions and 37 deletions

View File

@ -36,6 +36,14 @@ type UsageData struct {
LastUsed time.Time `json:"last_used"`
}
func getRequiredQueryParam(r *http.Request, queryParam string) string {
val := r.URL.Query().Get(queryParam)
if val == "" {
panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam))
}
return val
}
func updateUsageData(userId, deviceId string) {
var usageData []UsageData
GLOBAL_DB.Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
@ -80,8 +88,8 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
}
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("user_id")
deviceId := r.URL.Query().Get("device_id")
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(userId, deviceId)
tx := GLOBAL_DB.Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry
@ -97,8 +105,8 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
}
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("user_id")
deviceId := r.URL.Query().Get("device_id")
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(userId, deviceId)
// Increment the count
GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId)
@ -118,11 +126,9 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
w.Write(resp)
}
// TODO: Add a helper to get query params and require them since most of these are meant to be mandatory
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("user_id")
deviceId := r.URL.Query().Get("device_id")
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var existingDevicesCount int64 = -1
result := GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount)
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
@ -137,8 +143,8 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
}
func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("user_id")
deviceId := r.URL.Query().Get("device_id")
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
result := GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)
@ -153,8 +159,8 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
}
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("user_id")
requestingDeviceId := r.URL.Query().Get("requesting_device_id")
userId := getRequiredQueryParam(r, "user_id")
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
data, err := ioutil.ReadAll(r.Body)
if err != nil {
panic(err)
@ -188,8 +194,8 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
}
func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
commitHash := r.URL.Query().Get("commit_hash")
deviceId := r.URL.Query().Get("device_id")
commitHash := getRequiredQueryParam(r, "commit_hash")
deviceId := getRequiredQueryParam(r, "device_id")
forcedBanner := r.URL.Query().Get("forced_banner")
fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner)
w.Write([]byte(forcedBanner))