mirror of
https://github.com/ddworken/hishtory.git
synced 2025-01-11 16:58:47 +01:00
Make query params required rather than having weird undefined behavior
This commit is contained in:
parent
84182ba5c3
commit
0fac3b7286
@ -36,6 +36,14 @@ type UsageData struct {
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
}
|
||||
|
||||
func getRequiredQueryParam(r *http.Request, queryParam string) string {
|
||||
val := r.URL.Query().Get(queryParam)
|
||||
if val == "" {
|
||||
panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam))
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func updateUsageData(userId, deviceId string) {
|
||||
var usageData []UsageData
|
||||
GLOBAL_DB.Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
|
||||
@ -80,8 +88,8 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := r.URL.Query().Get("user_id")
|
||||
deviceId := r.URL.Query().Get("device_id")
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
updateUsageData(userId, deviceId)
|
||||
tx := GLOBAL_DB.Where("user_id = ?", userId)
|
||||
var historyEntries []*shared.EncHistoryEntry
|
||||
@ -97,8 +105,8 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := r.URL.Query().Get("user_id")
|
||||
deviceId := r.URL.Query().Get("device_id")
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
updateUsageData(userId, deviceId)
|
||||
// Increment the count
|
||||
GLOBAL_DB.Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId)
|
||||
@ -118,11 +126,9 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
// TODO: Add a helper to get query params and require them since most of these are meant to be mandatory
|
||||
|
||||
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := r.URL.Query().Get("user_id")
|
||||
deviceId := r.URL.Query().Get("device_id")
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
var existingDevicesCount int64 = -1
|
||||
result := GLOBAL_DB.Model(&shared.Device{}).Where("user_id = ?", userId).Count(&existingDevicesCount)
|
||||
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
|
||||
@ -137,8 +143,8 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := r.URL.Query().Get("user_id")
|
||||
deviceId := r.URL.Query().Get("device_id")
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
var dumpRequests []*shared.DumpRequest
|
||||
// Filter out ones requested by the hishtory instance that sent this request
|
||||
result := GLOBAL_DB.Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)
|
||||
@ -153,8 +159,8 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := r.URL.Query().Get("user_id")
|
||||
requestingDeviceId := r.URL.Query().Get("requesting_device_id")
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
|
||||
data, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -188,8 +194,8 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
commitHash := r.URL.Query().Get("commit_hash")
|
||||
deviceId := r.URL.Query().Get("device_id")
|
||||
commitHash := getRequiredQueryParam(r, "commit_hash")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
forcedBanner := r.URL.Query().Get("forced_banner")
|
||||
fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner)
|
||||
w.Write([]byte(forcedBanner))
|
||||
|
@ -135,7 +135,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
||||
|
||||
// Query for dump requests, there should be one for userId
|
||||
w := httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId, nil))
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err := ioutil.ReadAll(res.Body)
|
||||
@ -155,7 +155,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
||||
|
||||
// And one for otherUser
|
||||
w = httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser, nil))
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
@ -173,20 +173,9 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
||||
t.Fatalf("unexpected user ID")
|
||||
}
|
||||
|
||||
// And none if we query without a user ID
|
||||
w = httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
shared.Check(t, err)
|
||||
if string(respBody) != "[]" {
|
||||
t.Fatalf("got unexpected respBody: %#v", string(respBody))
|
||||
}
|
||||
|
||||
// And none if we query for a user ID that doesn't exit
|
||||
w = httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo", nil))
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
@ -197,7 +186,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
||||
|
||||
// And none for a missing user ID
|
||||
w = httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=", nil))
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
@ -218,9 +207,19 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
||||
submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2, bytes.NewReader(reqBody))
|
||||
apiSubmitDumpHandler(nil, submitReq)
|
||||
|
||||
// Check that the dump request is no longer there for userId
|
||||
// Check that the dump request is no longer there for userId for either device ID
|
||||
w = httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId, nil))
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
shared.Check(t, err)
|
||||
if string(respBody) != "[]" {
|
||||
t.Fatalf("got unexpected respBody: %#v", string(respBody))
|
||||
}
|
||||
w = httptest.NewRecorder()
|
||||
// The other user
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
@ -231,7 +230,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
||||
|
||||
// But it is there for the other user
|
||||
w = httptest.NewRecorder()
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser, nil))
|
||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
|
@ -905,7 +905,12 @@ func testRequestAndReceiveDbDump(t *testing.T, tester shellTester) {
|
||||
secretKey := installHishtory(t, tester, "")
|
||||
|
||||
// Confirm there are no pending dump requests
|
||||
resp, err := lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey))
|
||||
config, err := lib.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
deviceId1 := config.DeviceId
|
||||
resp, err := lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey) + "&device_id=" + deviceId1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending dump requests: %v", err)
|
||||
}
|
||||
@ -944,8 +949,8 @@ echo other`)
|
||||
// Install a new one (with the same secret key but a diff device id)
|
||||
installHishtory(t, tester, secretKey)
|
||||
|
||||
// Confirm there are pending dump requests
|
||||
resp, err = lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey))
|
||||
// Confirm there is now a pending dump requests that the first device should respond to
|
||||
resp, err = lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey) + "&device_id=" + deviceId1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending dump requests: %v", err)
|
||||
}
|
||||
@ -977,8 +982,8 @@ echo other`)
|
||||
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
|
||||
}
|
||||
|
||||
// Confirm there are no pending dump requests
|
||||
resp, err = lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey))
|
||||
// Confirm there are no pending dump requests for the first device
|
||||
resp, err = lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey) + "&device_id=" + deviceId1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get pending dump requests: %v", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user