Optimize number of round-trip HTTP connections made by the client by having the submit handler return metadata about whether there are pending dump/deletion requests

For now, I'm still keeping the dedicated endpoints for those functionalities, but since most of the time there are no dump/deletion requests this should cut down the number of requests made by the client by 2/3.
This commit is contained in:
David Dworken 2023-09-21 11:35:24 -07:00
parent b05fb0f818
commit 1e43de689f
No known key found for this signature in database
10 changed files with 156 additions and 63 deletions

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"encoding/json"
"fmt"
"html"
@ -27,6 +28,7 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
if len(entries) == 0 {
return
}
userId := entries[0].UserId
// TODO: add these to the context in a middleware
version := getHishtoryVersion(r)
@ -50,8 +52,32 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
deviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
resp := shared.SubmitResponse{
HaveDumpRequests: s.haveDumpRequests(r.Context(), userId, deviceId),
HaveDeletionRequests: s.haveDeletionRequests(r.Context(), userId, deviceId),
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
panic(err)
}
}
func (s *Server) haveDumpRequests(ctx context.Context, userId, deviceId string) bool {
if userId == "" || deviceId == "" {
return true
}
dumpRequests, err := s.db.DumpRequestForUserAndDevice(ctx, userId, deviceId)
checkGormError(err)
return len(dumpRequests) > 0
}
func (s *Server) haveDeletionRequests(ctx context.Context, userId, deviceId string) bool {
if userId == "" || deviceId == "" {
return true
}
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(ctx, userId, deviceId)
checkGormError(err)
return len(deletionRequests) > 0
}
func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
@ -171,8 +197,6 @@ func (s *Server) apiBannerHandler(w http.ResponseWriter, r *http.Request) {
func (s *Server) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err)

View File

@ -29,6 +29,10 @@ var DB *database.DB
const testDBDSN = "file::memory:?_journal_mode=WAL&cache=shared"
func TestMain(m *testing.M) {
// Set env variable
defer testutils.BackupAndRestoreEnv("HISHTORY_TEST")()
os.Setenv("HISHTORY_TEST", "1")
// setup test database
db, err := database.OpenSQLite(testDBDSN, &gorm.Config{})
if err != nil {
@ -73,37 +77,31 @@ func TestESubmitThenQuery(t *testing.T) {
testutils.Check(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w := httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w))
// Query for device id 1
w := httptest.NewRecorder()
w = httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
s.apiQueryHandler(w, searchReq)
require.Equal(t, w.Result().StatusCode, 200)
res := w.Result()
defer res.Body.Close()
respBody, err := io.ReadAll(res.Body)
testutils.Check(t, err)
var retrievedEntries []*shared.EncHistoryEntry
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
require.Equal(t, 1, len(retrievedEntries))
dbEntry := retrievedEntries[0]
if dbEntry.DeviceId != devId1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != data.UserId("key") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 0 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
require.Equal(t, devId1, dbEntry.DeviceId)
require.Equal(t, data.UserId("key"), dbEntry.UserId)
require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
testutils.Check(t, err)
if !data.EntryEquals(decEntry, entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
}
require.True(t, data.EntryEquals(decEntry, entry))
// Same for device id 2
w = httptest.NewRecorder()
@ -344,8 +342,11 @@ func TestDeletionRequests(t *testing.T) {
testutils.Check(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w := httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w))
// And another entry for user1
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
@ -354,8 +355,11 @@ func TestDeletionRequests(t *testing.T) {
testutils.Check(t, err)
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody))
w = httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w))
// And an entry for user2 that has the same timestamp as the previous entry
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
@ -365,11 +369,14 @@ func TestDeletionRequests(t *testing.T) {
testutils.Check(t, err)
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w = httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w))
// Query for device id 1
w := httptest.NewRecorder()
w = httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
s.apiQueryHandler(w, searchReq)
res := w.Result()
@ -469,6 +476,17 @@ func TestDeletionRequests(t *testing.T) {
t.Fatalf("DB data is different than input! \ndb =%#v\nentry=%#v", *dbEntry, entry3)
}
// Check that apiSubmit tells the client that there is a pending deletion request
encEntry, err = data.EncryptHistoryEntry("dkey", entry2)
testutils.Check(t, err)
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody))
w = httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: true}, deserializeSubmitResponse(t, w))
// Query for deletion requests
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
@ -563,8 +581,11 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
testutils.Check(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w := httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, shared.SubmitResponse{HaveDumpRequests: true, HaveDeletionRequests: false}, deserializeSubmitResponse(t, w))
// Call cleanDatabase and just check that there are no panics
testutils.Check(t, DB.Clean(context.TODO()))
@ -580,3 +601,9 @@ func assertNoLeakedConnections(t *testing.T, db *database.DB) {
t.Fatalf("expected DB to have not leak connections, actually have %d", numConns)
}
}
func deserializeSubmitResponse(t *testing.T, w *httptest.ResponseRecorder) shared.SubmitResponse {
submitResponse := shared.SubmitResponse{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &submitResponse))
return submitResponse
}

View File

@ -157,6 +157,7 @@ func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Reque
}
func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
// TODO: Change code like this to use json.NewDecoder for simplicity
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)

View File

@ -82,6 +82,14 @@ func getRequiredQueryParam(r *http.Request, queryParam string) string {
return val
}
func getOptionalQueryParam(r *http.Request, queryParam string, isTestEnvironment bool) string {
val := r.URL.Query().Get(queryParam)
if val == "" && isTestEnvironment {
panic(fmt.Sprintf("request to %s is missing optional query param=%#v that is required in test environments", r.URL, queryParam))
}
return val
}
func checkGormError(err error) {
if err == nil {
return

View File

@ -184,5 +184,3 @@ func main() {
panic(err)
}
}
// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
@ -937,11 +938,22 @@ func manuallySubmitHistoryEntry(t testing.TB, userSecret string, entry data.Hist
}
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
resp, err := http.Post("http://localhost:8080/api/v1/submit", "application/json", bytes.NewBuffer(jsonValue))
require.NotEqual(t, "", entry.DeviceId)
resp, err := http.Post("http://localhost:8080/api/v1/submit?source_device_id="+entry.DeviceId, "application/json", bytes.NewBuffer(jsonValue))
testutils.Check(t, err)
if resp.StatusCode != 200 {
t.Fatalf("failed to submit result to backend, status_code=%d", resp.StatusCode)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read resp.Body: %v", err)
}
submitResp := shared.SubmitResponse{}
err = json.Unmarshal(respBody, &submitResp)
if err != nil {
t.Fatalf("failed to deserialize SubmitResponse: %v", err)
}
}
func testTimestampsAreReasonablyCorrect(t *testing.T, tester shellTester) {

View File

@ -169,11 +169,21 @@ func saveHistoryEntry(ctx context.Context) {
lib.CheckFatalError(err)
// Persist it remotely
shouldCheckForDeletionRequests := true
shouldCheckForDumpRequests := true
if !config.IsOffline {
jsonValue, err := lib.EncryptAndMarshal(config, []*data.HistoryEntry{entry})
lib.CheckFatalError(err)
_, err = lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
if err != nil {
w, err := lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
if err == nil {
submitResponse := shared.SubmitResponse{}
err := json.Unmarshal(w, &submitResponse)
if err != nil {
lib.CheckFatalError(fmt.Errorf("failed to deserialize response from /api/v1/submit: %w", err))
}
shouldCheckForDeletionRequests = submitResponse.HaveDeletionRequests
shouldCheckForDumpRequests = submitResponse.HaveDumpRequests
} else {
if lib.IsOfflineError(err) {
hctx.GetLogger().Infof("Failed to remotely persist hishtory entry because we failed to connect to the remote server! This is likely because the device is offline, but also could be because the remote server is having reliability issues. Original error: %v", err)
if !config.HaveMissedUploads {
@ -188,38 +198,42 @@ func saveHistoryEntry(ctx context.Context) {
}
// Check if there is a pending dump request and reply to it if so
dumpRequests, err := lib.GetDumpRequests(config)
if err != nil {
if lib.IsOfflineError(err) {
// It is fine to just ignore this, the next command will retry the API and eventually we will respond to any pending dump requests
dumpRequests = []*shared.DumpRequest{}
hctx.GetLogger().Infof("Failed to check for dump requests because we failed to connect to the remote server!")
} else {
lib.CheckFatalError(err)
}
}
if len(dumpRequests) > 0 {
lib.CheckFatalError(lib.RetrieveAdditionalEntriesFromRemote(ctx))
entries, err := lib.Search(ctx, db, "", 0)
lib.CheckFatalError(err)
var encEntries []*shared.EncHistoryEntry
for _, entry := range entries {
enc, err := data.EncryptHistoryEntry(config.UserSecret, *entry)
lib.CheckFatalError(err)
encEntries = append(encEntries, &enc)
}
reqBody, err := json.Marshal(encEntries)
lib.CheckFatalError(err)
for _, dumpRequest := range dumpRequests {
if !config.IsOffline {
_, err := lib.ApiPost("/api/v1/submit-dump?user_id="+dumpRequest.UserId+"&requesting_device_id="+dumpRequest.RequestingDeviceId+"&source_device_id="+config.DeviceId, "application/json", reqBody)
if shouldCheckForDumpRequests {
dumpRequests, err := lib.GetDumpRequests(config)
if err != nil {
if lib.IsOfflineError(err) {
// It is fine to just ignore this, the next command will retry the API and eventually we will respond to any pending dump requests
dumpRequests = []*shared.DumpRequest{}
hctx.GetLogger().Infof("Failed to check for dump requests because we failed to connect to the remote server!")
} else {
lib.CheckFatalError(err)
}
}
if len(dumpRequests) > 0 {
lib.CheckFatalError(lib.RetrieveAdditionalEntriesFromRemote(ctx))
entries, err := lib.Search(ctx, db, "", 0)
lib.CheckFatalError(err)
var encEntries []*shared.EncHistoryEntry
for _, entry := range entries {
enc, err := data.EncryptHistoryEntry(config.UserSecret, *entry)
lib.CheckFatalError(err)
encEntries = append(encEntries, &enc)
}
reqBody, err := json.Marshal(encEntries)
lib.CheckFatalError(err)
for _, dumpRequest := range dumpRequests {
if !config.IsOffline {
_, err := lib.ApiPost("/api/v1/submit-dump?user_id="+dumpRequest.UserId+"&requesting_device_id="+dumpRequest.RequestingDeviceId+"&source_device_id="+config.DeviceId, "application/json", reqBody)
lib.CheckFatalError(err)
}
}
}
}
// Handle deletion requests
lib.CheckFatalError(lib.ProcessDeletionRequests(ctx))
if shouldCheckForDeletionRequests {
lib.CheckFatalError(lib.ProcessDeletionRequests(ctx))
}
if config.BetaMode {
db.Commit()

View File

@ -158,6 +158,7 @@ func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (Histo
}
func EntryEquals(entry1, entry2 HistoryEntry) bool {
// TODO: Can we remove this function? Or at least move it to a test-only file?
return entry1.LocalUsername == entry2.LocalUsername &&
entry1.Hostname == entry2.Hostname &&
entry1.Command == entry2.Command &&

View File

@ -114,6 +114,13 @@ type Feedback struct {
Feedback string `json:"feedback"`
}
// Response from submitting new history entries. Contains metadata that is used to avoid making additional round-trip
// requests to the hishtory backend.
type SubmitResponse struct {
HaveDumpRequests bool `json:"have_dump_requests"`
HaveDeletionRequests bool `json:"have_deletion_requests"`
}
func Chunks[k any](slice []k, chunkSize int) [][]k {
var chunks [][]k
for i := 0; i < len(slice); i += chunkSize {

View File

@ -276,7 +276,7 @@ func RunTestServer() func() {
panic(fmt.Sprintf("server failed to do something: stderr=%#v, stdout=%#v", stderr.String(), stdout.String()))
}
if strings.Contains(allOutput, "ERROR:") || strings.Contains(allOutput, "http: panic serving") {
panic(fmt.Sprintf("server experienced an error: stderr=%#v, stdout=%#v", stderr.String(), stdout.String()))
panic(fmt.Sprintf("server experienced an error\n\n\nstderr=\n%s\n\n\nstdout=%s", stderr.String(), stdout.String()))
}
// Persist test server logs for debugging
f, err := os.OpenFile("/tmp/hishtory-server.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
@ -325,6 +325,7 @@ func MakeFakeHistoryEntry(command string) data.HistoryEntry {
ExitCode: 2,
StartTime: time.Unix(fakeHistoryTimestamp, 0).UTC(),
EndTime: time.Unix(fakeHistoryTimestamp+3, 0).UTC(),
DeviceId: "fake_device_id",
}
}