From 1e43de689fa5ee8fb07862cc007c298389670bdb Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 21 Sep 2023 11:35:24 -0700 Subject: [PATCH] Optimize number of round-trip HTTP connections made by the client by having the submit handler return metadata about whether there are pending dump/deletion requests. For now, I'm still keeping the dedicated endpoints for those functionalities, but since most of the time there are no dump/deletion requests this should cut down the number of requests made by the client by 2/3. --- backend/server/internal/server/api.go | 32 +++++++- backend/server/internal/server/server_test.go | 81 ++++++++++++------- backend/server/internal/server/srv.go | 1 + backend/server/internal/server/util.go | 8 ++ backend/server/server.go | 2 - client/client_test.go | 14 +++- client/cmd/saveHistoryEntry.go | 70 +++++++++------- client/data/data.go | 1 + shared/data.go | 7 ++ shared/testutils/testutils.go | 3 +- 10 files changed, 156 insertions(+), 63 deletions(-) diff --git a/backend/server/internal/server/api.go b/backend/server/internal/server/api.go index 424dd7a..b9053b9 100644 --- a/backend/server/internal/server/api.go +++ b/backend/server/internal/server/api.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "fmt" "html" @@ -27,6 +28,7 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { if len(entries) == 0 { return } + userId := entries[0].UserId // TODO: add these to the context in a middleware version := getHishtoryVersion(r) @@ -50,8 +52,32 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) } - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) + deviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment) + resp := shared.SubmitResponse{ + HaveDumpRequests: s.haveDumpRequests(r.Context(), userId, deviceId), + HaveDeletionRequests: 
s.haveDeletionRequests(r.Context(), userId, deviceId), + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + panic(err) + } +} + +func (s *Server) haveDumpRequests(ctx context.Context, userId, deviceId string) bool { + if userId == "" || deviceId == "" { + return true + } + dumpRequests, err := s.db.DumpRequestForUserAndDevice(ctx, userId, deviceId) + checkGormError(err) + return len(dumpRequests) > 0 +} + +func (s *Server) haveDeletionRequests(ctx context.Context, userId, deviceId string) bool { + if userId == "" || deviceId == "" { + return true + } + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(ctx, userId, deviceId) + checkGormError(err) + return len(deletionRequests) > 0 } func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { @@ -171,8 +197,6 @@ func (s *Server) apiBannerHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") - var dumpRequests []*shared.DumpRequest - // Filter out ones requested by the hishtory instance that sent this request dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) checkGormError(err) diff --git a/backend/server/internal/server/server_test.go b/backend/server/internal/server/server_test.go index 858f19d..b5ca045 100644 --- a/backend/server/internal/server/server_test.go +++ b/backend/server/internal/server/server_test.go @@ -29,6 +29,10 @@ var DB *database.DB const testDBDSN = "file::memory:?_journal_mode=WAL&cache=shared" func TestMain(m *testing.M) { + // Set env variable + defer testutils.BackupAndRestoreEnv("HISHTORY_TEST")() + os.Setenv("HISHTORY_TEST", "1") + // setup test database db, err := database.OpenSQLite(testDBDSN, &gorm.Config{}) if err != nil { @@ -73,37 +77,31 @@ func TestESubmitThenQuery(t *testing.T) { testutils.Check(t, err) reqBody, err := 
json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) - submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - s.apiSubmitHandler(httptest.NewRecorder(), submitReq) + submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) + w := httptest.NewRecorder() + s.apiSubmitHandler(w, submitReq) + require.Equal(t, 200, w.Result().StatusCode) + require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w)) // Query for device id 1 - w := httptest.NewRecorder() + w = httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) s.apiQueryHandler(w, searchReq) + require.Equal(t, w.Result().StatusCode, 200) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) testutils.Check(t, err) var retrievedEntries []*shared.EncHistoryEntry testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 1 { - t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) - } + require.Equal(t, 1, len(retrievedEntries)) dbEntry := retrievedEntries[0] - if dbEntry.DeviceId != devId1 { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.UserId != data.UserId("key") { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.ReadCount != 0 { - t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) - } + require.Equal(t, devId1, dbEntry.DeviceId) + require.Equal(t, data.UserId("key"), dbEntry.UserId) + require.Equal(t, 0, dbEntry.ReadCount) decEntry, err := data.DecryptHistoryEntry("key", *dbEntry) testutils.Check(t, err) - if !data.EntryEquals(decEntry, entry) { - t.Fatalf("DB data is different than input! 
\ndb =%#v\ninput=%#v", *dbEntry, entry) - } + require.True(t, data.EntryEquals(decEntry, entry)) // Same for device id 2 w = httptest.NewRecorder() @@ -344,8 +342,11 @@ func TestDeletionRequests(t *testing.T) { testutils.Check(t, err) reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) - submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - s.apiSubmitHandler(httptest.NewRecorder(), submitReq) + submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) + w := httptest.NewRecorder() + s.apiSubmitHandler(w, submitReq) + require.Equal(t, 200, w.Result().StatusCode) + require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w)) // And another entry for user1 entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -354,8 +355,11 @@ func TestDeletionRequests(t *testing.T) { testutils.Check(t, err) reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) - submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - s.apiSubmitHandler(httptest.NewRecorder(), submitReq) + submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody)) + w = httptest.NewRecorder() + s.apiSubmitHandler(w, submitReq) + require.Equal(t, 200, w.Result().StatusCode) + require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w)) // And an entry for user2 that has the same timestamp as the previous entry entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -365,11 +369,14 @@ func TestDeletionRequests(t *testing.T) { testutils.Check(t, err) reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) - submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - s.apiSubmitHandler(httptest.NewRecorder(), 
submitReq) + submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) + w = httptest.NewRecorder() + s.apiSubmitHandler(w, submitReq) + require.Equal(t, 200, w.Result().StatusCode) + require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w)) // Query for device id 1 - w := httptest.NewRecorder() + w = httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) s.apiQueryHandler(w, searchReq) res := w.Result() @@ -469,6 +476,17 @@ func TestDeletionRequests(t *testing.T) { t.Fatalf("DB data is different than input! \ndb =%#v\nentry=%#v", *dbEntry, entry3) } + // Check that apiSubmit tells the client that there is a pending deletion request + encEntry, err = data.EncryptHistoryEntry("dkey", entry2) + testutils.Check(t, err) + reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) + testutils.Check(t, err) + submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody)) + w = httptest.NewRecorder() + s.apiSubmitHandler(w, submitReq) + require.Equal(t, 200, w.Result().StatusCode) + require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: true}, deserializeSubmitResponse(t, w)) + // Query for deletion requests w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) @@ -563,8 +581,11 @@ func TestCleanDatabaseNoErrors(t *testing.T) { testutils.Check(t, err) reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) - submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - s.apiSubmitHandler(httptest.NewRecorder(), submitReq) + submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) + w := httptest.NewRecorder() + s.apiSubmitHandler(w, submitReq) + 
require.Equal(t, 200, w.Result().StatusCode) + require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w)) // Call cleanDatabase and just check that there are no panics testutils.Check(t, DB.Clean(context.TODO())) @@ -580,3 +601,9 @@ func assertNoLeakedConnections(t *testing.T, db *database.DB) { t.Fatalf("expected DB to have not leak connections, actually have %d", numConns) } } + +func deserializeSubmitResponse(t *testing.T, w *httptest.ResponseRecorder) shared.SubmitResponse { + submitResponse := shared.SubmitResponse{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &submitResponse)) + return submitResponse +} diff --git a/backend/server/internal/server/srv.go b/backend/server/internal/server/srv.go index 46582a1..e130e13 100644 --- a/backend/server/internal/server/srv.go +++ b/backend/server/internal/server/srv.go @@ -157,6 +157,7 @@ func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Reque } func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + // TODO: Change code like this to use json.NewDecoder for simplicity data, err := io.ReadAll(r.Body) if err != nil { panic(err) diff --git a/backend/server/internal/server/util.go b/backend/server/internal/server/util.go index 9925271..218ac67 100644 --- a/backend/server/internal/server/util.go +++ b/backend/server/internal/server/util.go @@ -82,6 +82,14 @@ func getRequiredQueryParam(r *http.Request, queryParam string) string { return val } +func getOptionalQueryParam(r *http.Request, queryParam string, isTestEnvironment bool) string { + val := r.URL.Query().Get(queryParam) + if val == "" && isTestEnvironment { + panic(fmt.Sprintf("request to %s is missing optional query param=%#v that is required in test environments", r.URL, queryParam)) + } + return val +} + func checkGormError(err error) { if err == nil { return diff --git a/backend/server/server.go b/backend/server/server.go index 
f48d2c8..9210955 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -184,5 +184,3 @@ func main() { panic(err) } } - -// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required? diff --git a/client/client_test.go b/client/client_test.go index 6fa28ec..6543d80 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "os" "os/exec" @@ -937,11 +938,22 @@ func manuallySubmitHistoryEntry(t testing.TB, userSecret string, entry data.Hist } jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) - resp, err := http.Post("http://localhost:8080/api/v1/submit", "application/json", bytes.NewBuffer(jsonValue)) + require.NotEqual(t, "", entry.DeviceId) + resp, err := http.Post("http://localhost:8080/api/v1/submit?source_device_id="+entry.DeviceId, "application/json", bytes.NewBuffer(jsonValue)) testutils.Check(t, err) if resp.StatusCode != 200 { t.Fatalf("failed to submit result to backend, status_code=%d", resp.StatusCode) } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read resp.Body: %v", err) + } + submitResp := shared.SubmitResponse{} + err = json.Unmarshal(respBody, &submitResp) + if err != nil { + t.Fatalf("failed to deserialize SubmitResponse: %v", err) + } } func testTimestampsAreReasonablyCorrect(t *testing.T, tester shellTester) { diff --git a/client/cmd/saveHistoryEntry.go b/client/cmd/saveHistoryEntry.go index b2d69ee..2bd0ec2 100644 --- a/client/cmd/saveHistoryEntry.go +++ b/client/cmd/saveHistoryEntry.go @@ -169,11 +169,21 @@ func saveHistoryEntry(ctx context.Context) { lib.CheckFatalError(err) // Persist it remotely + shouldCheckForDeletionRequests := true + shouldCheckForDumpRequests := true if !config.IsOffline { jsonValue, err := lib.EncryptAndMarshal(config, []*data.HistoryEntry{entry}) lib.CheckFatalError(err) - _, err = 
lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue) - if err != nil { + w, err := lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue) + if err == nil { + submitResponse := shared.SubmitResponse{} + err := json.Unmarshal(w, &submitResponse) + if err != nil { + lib.CheckFatalError(fmt.Errorf("failed to deserialize response from /api/v1/submit: %w", err)) + } + shouldCheckForDeletionRequests = submitResponse.HaveDeletionRequests + shouldCheckForDumpRequests = submitResponse.HaveDumpRequests + } else { if lib.IsOfflineError(err) { hctx.GetLogger().Infof("Failed to remotely persist hishtory entry because we failed to connect to the remote server! This is likely because the device is offline, but also could be because the remote server is having reliability issues. Original error: %v", err) if !config.HaveMissedUploads { @@ -188,38 +198,42 @@ func saveHistoryEntry(ctx context.Context) { } // Check if there is a pending dump request and reply to it if so - dumpRequests, err := lib.GetDumpRequests(config) - if err != nil { - if lib.IsOfflineError(err) { - // It is fine to just ignore this, the next command will retry the API and eventually we will respond to any pending dump requests - dumpRequests = []*shared.DumpRequest{} - hctx.GetLogger().Infof("Failed to check for dump requests because we failed to connect to the remote server!") - } else { - lib.CheckFatalError(err) - } - } - if len(dumpRequests) > 0 { - lib.CheckFatalError(lib.RetrieveAdditionalEntriesFromRemote(ctx)) - entries, err := lib.Search(ctx, db, "", 0) - lib.CheckFatalError(err) - var encEntries []*shared.EncHistoryEntry - for _, entry := range entries { - enc, err := data.EncryptHistoryEntry(config.UserSecret, *entry) - lib.CheckFatalError(err) - encEntries = append(encEntries, &enc) - } - reqBody, err := json.Marshal(encEntries) - lib.CheckFatalError(err) - for _, dumpRequest := range dumpRequests { - if 
!config.IsOffline { - _, err := lib.ApiPost("/api/v1/submit-dump?user_id="+dumpRequest.UserId+"&requesting_device_id="+dumpRequest.RequestingDeviceId+"&source_device_id="+config.DeviceId, "application/json", reqBody) + if shouldCheckForDumpRequests { + dumpRequests, err := lib.GetDumpRequests(config) + if err != nil { + if lib.IsOfflineError(err) { + // It is fine to just ignore this, the next command will retry the API and eventually we will respond to any pending dump requests + dumpRequests = []*shared.DumpRequest{} + hctx.GetLogger().Infof("Failed to check for dump requests because we failed to connect to the remote server!") + } else { lib.CheckFatalError(err) } } + if len(dumpRequests) > 0 { + lib.CheckFatalError(lib.RetrieveAdditionalEntriesFromRemote(ctx)) + entries, err := lib.Search(ctx, db, "", 0) + lib.CheckFatalError(err) + var encEntries []*shared.EncHistoryEntry + for _, entry := range entries { + enc, err := data.EncryptHistoryEntry(config.UserSecret, *entry) + lib.CheckFatalError(err) + encEntries = append(encEntries, &enc) + } + reqBody, err := json.Marshal(encEntries) + lib.CheckFatalError(err) + for _, dumpRequest := range dumpRequests { + if !config.IsOffline { + _, err := lib.ApiPost("/api/v1/submit-dump?user_id="+dumpRequest.UserId+"&requesting_device_id="+dumpRequest.RequestingDeviceId+"&source_device_id="+config.DeviceId, "application/json", reqBody) + lib.CheckFatalError(err) + } + } + } } // Handle deletion requests - lib.CheckFatalError(lib.ProcessDeletionRequests(ctx)) + if shouldCheckForDeletionRequests { + lib.CheckFatalError(lib.ProcessDeletionRequests(ctx)) + } if config.BetaMode { db.Commit() diff --git a/client/data/data.go b/client/data/data.go index 17d3f70..cea1035 100644 --- a/client/data/data.go +++ b/client/data/data.go @@ -158,6 +158,7 @@ func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (Histo } func EntryEquals(entry1, entry2 HistoryEntry) bool { + // TODO: Can we remove this function? 
Or at least move it to a test-only file? return entry1.LocalUsername == entry2.LocalUsername && entry1.Hostname == entry2.Hostname && entry1.Command == entry2.Command && diff --git a/shared/data.go b/shared/data.go index 08f7132..2a59282 100644 --- a/shared/data.go +++ b/shared/data.go @@ -114,6 +114,13 @@ type Feedback struct { Feedback string `json:"feedback"` } +// Response from submitting new history entries. Contains metadata that is used to avoid making additional round-trip +// requests to the hishtory backend. +type SubmitResponse struct { + HaveDumpRequests bool `json:"have_dump_requests"` + HaveDeletionRequests bool `json:"have_deletion_requests"` +} + func Chunks[k any](slice []k, chunkSize int) [][]k { var chunks [][]k for i := 0; i < len(slice); i += chunkSize { diff --git a/shared/testutils/testutils.go b/shared/testutils/testutils.go index 4ed5def..0e207e8 100644 --- a/shared/testutils/testutils.go +++ b/shared/testutils/testutils.go @@ -276,7 +276,7 @@ func RunTestServer() func() { panic(fmt.Sprintf("server failed to do something: stderr=%#v, stdout=%#v", stderr.String(), stdout.String())) } if strings.Contains(allOutput, "ERROR:") || strings.Contains(allOutput, "http: panic serving") { - panic(fmt.Sprintf("server experienced an error: stderr=%#v, stdout=%#v", stderr.String(), stdout.String())) + panic(fmt.Sprintf("server experienced an error\n\n\nstderr=\n%s\n\n\nstdout=%s", stderr.String(), stdout.String())) } // Persist test server logs for debugging f, err := os.OpenFile("/tmp/hishtory-server.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) @@ -325,6 +325,7 @@ func MakeFakeHistoryEntry(command string) data.HistoryEntry { ExitCode: 2, StartTime: time.Unix(fakeHistoryTimestamp, 0).UTC(), EndTime: time.Unix(fakeHistoryTimestamp+3, 0).UTC(), + DeviceId: "fake_device_id", } }