From 049901098123b95607c9b32c43866698d3c9d4c1 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 28 Apr 2022 10:56:59 -0700 Subject: [PATCH] Remove the 'e' prefix from api endpoints + implement backend APIs for clean loading of all data from other instances --- backend/server/server.go | 108 ++++++++++++++++------- backend/server/server_test.go | 156 ++++++++++++++++++++++++++++++++-- client/lib/lib.go | 4 +- hishtory.go | 6 +- 4 files changed, 230 insertions(+), 44 deletions(-) diff --git a/backend/server/server.go b/backend/server/server.go index b3261f1..cc9a4be 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "os" + "os/user" "strconv" "strings" "time" @@ -45,7 +46,13 @@ func updateUsageData(userId, deviceId string) { } } -func apiESubmitHandler(w http.ResponseWriter, r *http.Request) { +type DumpRequest struct { + UserId string `json:"user_id"` + RequestingDeviceId string `json:"requesting_device_id"` + RequestTime time.Time `json:"request_time"` +} + +func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { data, err := ioutil.ReadAll(r.Body) if err != nil { panic(err) @@ -55,7 +62,7 @@ func apiESubmitHandler(w http.ResponseWriter, r *http.Request) { if err != nil { panic(fmt.Sprintf("body=%#v, err=%v", data, err)) } - fmt.Printf("apiESubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) + fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) for _, entry := range entries { updateUsageData(entry.UserId, entry.DeviceId) tx := GLOBAL_DB.Where("user_id = ?", entry.UserId) @@ -67,7 +74,7 @@ func apiESubmitHandler(w http.ResponseWriter, r *http.Request) { if len(devices) == 0 { panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entry.UserId)) } - fmt.Printf("apiESubmitHandler: Found %d devices\n", len(devices)) + fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) for _, device := range devices { entry.DeviceId = device.DeviceId result := GLOBAL_DB.Create(&entry) @@ -78,7 +85,7 @@ func apiESubmitHandler(w http.ResponseWriter, r *http.Request) { } } -func apiEQueryHandler(w http.ResponseWriter, r *http.Request) { +func apiQueryHandler(w http.ResponseWriter, r *http.Request) { userId := r.URL.Query().Get("user_id") deviceId := r.URL.Query().Get("device_id") updateUsageData(userId, deviceId) @@ -92,7 +99,7 @@ func apiEQueryHandler(w http.ResponseWriter, r *http.Request) { if result.Error != nil { panic(fmt.Errorf("DB query error: %v", result.Error)) } - fmt.Printf("apiEQueryHandler: Found %d entries\n", len(historyEntries)) + fmt.Printf("apiQueryHandler: Found %d entries\n", len(historyEntries)) resp, err := json.Marshal(historyEntries) if err != nil { panic(err) @@ -100,31 +107,63 @@ func apiEQueryHandler(w http.ResponseWriter, r *http.Request) { w.Write(resp) } -// TODO: bootstrap is a janky solution for the initial version of this. Long term, need to support deleting entries from the DB which means replacing bootstrap with a queued message sent to any live instances. -func apiEBootstrapHandler(w http.ResponseWriter, r *http.Request) { - userId := r.URL.Query().Get("user_id") - deviceId := r.URL.Query().Get("device_id") - updateUsageData(userId, deviceId) - tx := GLOBAL_DB.Where("user_id = ?", userId) - var historyEntries []*shared.EncHistoryEntry - result := tx.Find(&historyEntries) - if result.Error != nil { - panic(fmt.Errorf("DB query error: %v", result.Error)) - } - resp, err := json.Marshal(historyEntries) - if err != nil { - panic(err) - } - w.Write(resp) -} - -func apiERegisterHandler(w http.ResponseWriter, r *http.Request) { +func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { userId := r.URL.Query().Get("user_id") deviceId := r.URL.Query().Get("device_id") GLOBAL_DB.Create(&shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: r.RemoteAddr, RegistrationDate: time.Now()}) + GLOBAL_DB.Create(&DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) updateUsageData(userId, deviceId) } +func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := r.URL.Query().Get("user_id") + var dumpRequests []*DumpRequest + result := GLOBAL_DB.Where("user_id = ?", userId).Find(&dumpRequests) + if result.Error != nil { + panic(fmt.Errorf("DB query error: %v", result.Error)) + } + respBody, err := json.Marshal(dumpRequests) + if err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) + } + w.Write(respBody) +} + +func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { + userId := r.URL.Query().Get("user_id") + requestingDeviceId := r.URL.Query().Get("requesting_device_id") + data, err := ioutil.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) + err = GLOBAL_DB.Transaction(func(tx *gorm.DB) error { + for _, entry := range entries { + entry.DeviceId = requestingDeviceId + if entry.UserId != userId { + return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId) + } + result := tx.Create(&entry) + if result.Error != nil { + return fmt.Errorf("failed to create entry: %v", err) + } + } + return nil + }) + if err != nil { + panic(fmt.Errorf("failed to execute transaction to add dumped DB: %v", err)) + } + result := GLOBAL_DB.Delete(&DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId) + if result.Error != nil { + panic(fmt.Errorf("failed to clear the dump request: %v", err)) + } +} + func apiBannerHandler(w http.ResponseWriter, r *http.Request) { commitHash := r.URL.Query().Get("commit_hash") deviceId := r.URL.Query().Get("device_id") @@ -134,7 +173,11 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) { } func isTestEnvironment() bool { - return os.Getenv("HISHTORY_TEST") != "" + u, err := user.Current() + if err != nil { + panic(err) + } + return os.Getenv("HISHTORY_TEST") != "" || u.Username == "david" } func OpenDB() (*gorm.DB, error) { @@ -146,6 +189,8 @@ func OpenDB() (*gorm.DB, error) { db.AutoMigrate(&shared.EncHistoryEntry{}) db.AutoMigrate(&shared.Device{}) db.AutoMigrate(&UsageData{}) + db.AutoMigrate(&DumpRequest{}) + db.Exec("PRAGMA journal_mode = WAL") return db, nil } @@ -156,6 +201,7 @@ func OpenDB() (*gorm.DB, error) { db.AutoMigrate(&shared.EncHistoryEntry{}) db.AutoMigrate(&shared.Device{}) db.AutoMigrate(&UsageData{}) + db.AutoMigrate(&DumpRequest{}) return db, nil } @@ -180,6 +226,7 @@ func cron() error { } func runBackgroundJobs() { + time.Sleep(5 * time.Second) for { err := cron() if err != nil { @@ -374,12 +421,13 @@ func cleanDatabase() error { func main() { fmt.Println("Listening on localhost:8080") - http.Handle("/api/v1/esubmit", withLogging(apiESubmitHandler)) - http.Handle("/api/v1/equery", withLogging(apiEQueryHandler)) - http.Handle("/api/v1/ebootstrap", withLogging(apiEBootstrapHandler)) - http.Handle("/api/v1/eregister", withLogging(apiERegisterHandler)) + http.Handle("/api/v1/submit", withLogging(apiSubmitHandler)) + http.Handle("/api/v1/get-dump-requests", withLogging(apiGetPendingDumpRequestsHandler)) + http.Handle("/api/v1/submit-dump", withLogging(apiSubmitDumpHandler)) + http.Handle("/api/v1/query", withLogging(apiQueryHandler)) + http.Handle("/api/v1/register", withLogging(apiRegisterHandler)) http.Handle("/api/v1/banner", withLogging(apiBannerHandler)) - http.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler)) http.Handle("/api/v1/download", withLogging(apiDownloadHandler)) + http.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler)) log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/backend/server/server_test.go b/backend/server/server_test.go index 7e7e1d2..91007d3 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -26,11 +26,11 @@ func TestESubmitThenQuery(t *testing.T) { otherUser := data.UserId("otherkey") otherDev := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiERegisterHandler(nil, deviceReq) + apiRegisterHandler(nil, deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiERegisterHandler(nil, deviceReq) + apiRegisterHandler(nil, deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil) - apiERegisterHandler(nil, deviceReq) + apiRegisterHandler(nil, deviceReq) // Submit a few entries for different devices entry := data.MakeFakeHistoryEntry("ls ~/") @@ -39,12 +39,12 @@ func TestESubmitThenQuery(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) shared.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiESubmitHandler(nil, submitReq) + apiSubmitHandler(nil, submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiEQueryHandler(w, searchReq) + apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := ioutil.ReadAll(res.Body) @@ -73,7 +73,7 @@ func TestESubmitThenQuery(t *testing.T) { // Same for device id 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiEQueryHandler(w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = ioutil.ReadAll(res.Body) @@ -97,19 +97,157 @@ func TestESubmitThenQuery(t *testing.T) { if !data.EntryEquals(decEntry, entry) { t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry) } +} - // Bootstrap handler should return 2 entries, one for each device +func TestDumpRequestAndResponse(t *testing.T) { + // Set up + defer shared.BackupAndRestore(t)() + InitDB() + + // Register a first device for two different users + userId := data.UserId("dkey") + devId1 := uuid.Must(uuid.NewRandom()).String() + otherUser := data.UserId("dOtherkey") + otherDev1 := uuid.Must(uuid.NewRandom()).String() + deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) + apiRegisterHandler(nil, deviceReq) + deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) + apiRegisterHandler(nil, deviceReq) + + // Query for dump requests, there should be one for userId + w := httptest.NewRecorder() + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId, nil)) + res := w.Result() + defer res.Body.Close() + respBody, err := ioutil.ReadAll(res.Body) + shared.Check(t, err) + var dumpRequests []*DumpRequest + shared.Check(t, json.Unmarshal(respBody, &dumpRequests)) + if len(dumpRequests) != 1 { + t.Fatalf("expected one pending dump request, got %#v", dumpRequests) + } + dumpRequest := dumpRequests[0] + if dumpRequest.RequestingDeviceId != devId1 { + t.Fatalf("unexpected device ID") + } + if dumpRequest.UserId != userId { + t.Fatalf("unexpected user ID") + } + + // And one for otherUser w = httptest.NewRecorder() - searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil) - apiEBootstrapHandler(w, searchReq) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser, nil)) res = w.Result() defer res.Body.Close() respBody, err = ioutil.ReadAll(res.Body) shared.Check(t, err) + dumpRequests = make([]*DumpRequest, 0) + shared.Check(t, json.Unmarshal(respBody, &dumpRequests)) + if len(dumpRequests) != 1 { + t.Fatalf("expected one pending dump request, got %#v", dumpRequests) + } + dumpRequest = dumpRequests[0] + if dumpRequest.RequestingDeviceId != otherDev1 { + t.Fatalf("unexpected device ID") + } + if dumpRequest.UserId != otherUser { + t.Fatalf("unexpected user ID") + } + + // And none if we query without a user ID + w = httptest.NewRecorder() + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) + res = w.Result() + defer res.Body.Close() + respBody, err = ioutil.ReadAll(res.Body) + shared.Check(t, err) + if string(respBody) != "[]" { + t.Fatalf("got unexpected respBody: %#v", string(respBody)) + } + + // And none if we query for a user ID that doesn't exit + w = httptest.NewRecorder() + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo", nil)) + res = w.Result() + defer res.Body.Close() + respBody, err = ioutil.ReadAll(res.Body) + shared.Check(t, err) + if string(respBody) != "[]" { + t.Fatalf("got unexpected respBody: %#v", string(respBody)) + } + + // Now submit a dump for userId + entry1Dec := data.MakeFakeHistoryEntry("ls ~/") + entry1, err := data.EncryptHistoryEntry("dkey", entry1Dec) + shared.Check(t, err) + entry2Dec := data.MakeFakeHistoryEntry("aaaaaaƔaaa") + entry2, err := data.EncryptHistoryEntry("dkey", entry1Dec) + shared.Check(t, err) + reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2}) + shared.Check(t, err) + submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId1, bytes.NewReader(reqBody)) + apiSubmitDumpHandler(nil, submitReq) + + // Check that the dump request is no longer there for userId + w = httptest.NewRecorder() + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId, nil)) + res = w.Result() + defer res.Body.Close() + respBody, err = ioutil.ReadAll(res.Body) + shared.Check(t, err) + if string(respBody) != "[]" { + t.Fatalf("got unexpected respBody: %#v", string(respBody)) + } + + // But it is there for the other user + w = httptest.NewRecorder() + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser, nil)) + res = w.Result() + defer res.Body.Close() + respBody, err = ioutil.ReadAll(res.Body) + shared.Check(t, err) + dumpRequests = make([]*DumpRequest, 0) + shared.Check(t, json.Unmarshal(respBody, &dumpRequests)) + if len(dumpRequests) != 1 { + t.Fatalf("expected one pending dump request, got %#v", dumpRequests) + } + dumpRequest = dumpRequests[0] + if dumpRequest.RequestingDeviceId != otherDev1 { + t.Fatalf("unexpected device ID") + } + if dumpRequest.UserId != otherUser { + t.Fatalf("unexpected user ID") + } + + // And finally, query to ensure that the dumped entries are in the DB + w = httptest.NewRecorder() + searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) + apiQueryHandler(w, searchReq) + res = w.Result() + defer res.Body.Close() + respBody, err = ioutil.ReadAll(res.Body) + shared.Check(t, err) + var retrievedEntries []*shared.EncHistoryEntry shared.Check(t, json.Unmarshal(respBody, &retrievedEntries)) if len(retrievedEntries) != 2 { t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries)) } + for _, dbEntry := range retrievedEntries { + if dbEntry.DeviceId != devId1 { + t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) + } + if dbEntry.UserId != userId { + t.Fatalf("Response contains an incorrect user ID: %#v", *dbEntry) + } + if dbEntry.ReadCount != 1 { + t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) + } + decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) + shared.Check(t, err) + if !data.EntryEquals(decEntry, entry1Dec) && !data.EntryEquals(decEntry, entry2Dec) { + t.Fatalf("DB data is different than input! \ndb =%#v\nentry1=%#v\nentry2=%#v", *dbEntry, entry1Dec, entry2Dec) + } + } } func TestUpdateReleaseVersion(t *testing.T) { diff --git a/client/lib/lib.go b/client/lib/lib.go index bdea616..43c2de8 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -198,12 +198,12 @@ func Setup(args []string) error { db.Exec("DELETE FROM history_entries") // Bootstrap from remote date - _, err = ApiGet("/api/v1/eregister?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId) + _, err = ApiGet("/api/v1/register?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId) if err != nil { return fmt.Errorf("failed to register device with backend: %v", err) } - respBody, err := ApiGet("/api/v1/ebootstrap?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId) + respBody, err := ApiGet("/api/v1/bootstrap?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId) if err != nil { return fmt.Errorf("failed to bootstrap device from the backend: %v", err) } diff --git a/hishtory.go b/hishtory.go index 464f462..41efe60 100644 --- a/hishtory.go +++ b/hishtory.go @@ -60,7 +60,7 @@ func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error { if err != nil { return err } - respBody, err := lib.ApiGet("/api/v1/equery?device_id=" + config.DeviceId + "&user_id=" + data.UserId(config.UserSecret)) + respBody, err := lib.ApiGet("/api/v1/query?device_id=" + config.DeviceId + "&user_id=" + data.UserId(config.UserSecret)) if err != nil { return err } @@ -131,10 +131,10 @@ func saveHistoryEntry() { encEntry.DeviceId = config.DeviceId jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) lib.CheckFatalError(err) - _, err = lib.ApiPost("/api/v1/esubmit", "application/json", jsonValue) + _, err = lib.ApiPost("/api/v1/submit", "application/json", jsonValue) if err != nil { if strings.Contains(err.Error(), "dial tcp: lookup api.hishtory.dev") { - // TODO: Somehow handle this + // TODO: Somehow handle this and don't completely lose it lib.GetLogger().Printf("Failed to remotely persist hishtory entry because the device is offline!") } else { lib.CheckFatalError(err)