mirror of
https://github.com/ddworken/hishtory.git
synced 2025-02-02 19:49:33 +01:00
* Fix double-syncing error where devices receive entries from themselves * Fix incorrect error message * Add TODO * Update TestESubmitThenQuery after making query more efficient * Update TestDeletionRequests and remove unnecessary asserts * Swap server_test.go to using require * Fix incorrect require due to typo
This commit is contained in:
parent
21c7f5e0db
commit
60cbb1976c
@ -31,7 +31,7 @@ func (db *DB) AllHistoryEntriesForUser(ctx context.Context, userID string) ([]*s
|
|||||||
|
|
||||||
func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
|
func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
|
||||||
var historyEntries []*shared.EncHistoryEntry
|
var historyEntries []*shared.EncHistoryEntry
|
||||||
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ?", deviceID, limit).Find(&historyEntries)
|
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ? AND NOT is_from_same_device", deviceID, limit).Find(&historyEntries)
|
||||||
|
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||||
@ -52,12 +52,13 @@ func (db *DB) AddHistoryEntries(ctx context.Context, entries ...*shared.EncHisto
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, devices []*Device, entries []*shared.EncHistoryEntry) error {
|
func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, sourceDeviceId string, devices []*Device, entries []*shared.EncHistoryEntry) error {
|
||||||
chunkSize := 1000
|
chunkSize := 1000
|
||||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
for _, device := range devices {
|
for _, device := range devices {
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
entry.DeviceId = device.DeviceId
|
entry.DeviceId = device.DeviceId
|
||||||
|
entry.IsFromSameDevice = sourceDeviceId == device.DeviceId
|
||||||
}
|
}
|
||||||
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
|
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
|
||||||
for _, entriesChunk := range shared.Chunks(entries, chunkSize) {
|
for _, entriesChunk := range shared.Chunks(entries, chunkSize) {
|
||||||
|
@ -39,7 +39,8 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
|
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
|
||||||
|
|
||||||
err = s.db.AddHistoryEntriesForAllDevices(r.Context(), devices, entries)
|
sourceDeviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
|
||||||
|
err = s.db.AddHistoryEntriesForAllDevices(r.Context(), sourceDeviceId, devices, entries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
|
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
|
||||||
}
|
}
|
||||||
@ -49,21 +50,20 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
resp := shared.SubmitResponse{}
|
resp := shared.SubmitResponse{}
|
||||||
|
|
||||||
deviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
|
if sourceDeviceId != "" {
|
||||||
if deviceId != "" {
|
|
||||||
hv, err := shared.ParseVersionString(version)
|
hv, err := shared.ParseVersionString(version)
|
||||||
if err != nil || hv.GreaterThan(shared.ParsedVersion{MinorVersion: 0, MajorVersion: 221}) {
|
if err != nil || hv.GreaterThan(shared.ParsedVersion{MinorVersion: 0, MajorVersion: 221}) {
|
||||||
// Note that if we fail to parse the version string, we do return dump and deletion requests. This is necessary
|
// Note that if we fail to parse the version string, we do return dump and deletion requests. This is necessary
|
||||||
// since tests run with v0.Unknown which obviously fails to parse.
|
// since tests run with v0.Unknown which obviously fails to parse.
|
||||||
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
|
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, sourceDeviceId)
|
||||||
checkGormError(err)
|
checkGormError(err)
|
||||||
resp.DumpRequests = dumpRequests
|
resp.DumpRequests = dumpRequests
|
||||||
|
|
||||||
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, sourceDeviceId)
|
||||||
checkGormError(err)
|
checkGormError(err)
|
||||||
resp.DeletionRequests = deletionRequests
|
resp.DeletionRequests = deletionRequests
|
||||||
|
|
||||||
checkGormError(s.db.DeletionRequestInc(r.Context(), userId, deviceId))
|
checkGormError(s.db.DeletionRequestInc(r.Context(), userId, sourceDeviceId))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,6 +73,7 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// TODO: Update this to filter out duplicate entries
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
version := getHishtoryVersion(r)
|
version := getHishtoryVersion(r)
|
||||||
|
@ -72,7 +72,7 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil)
|
||||||
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
|
|
||||||
// Submit a few entries for different devices
|
// Submit an entry from device 1
|
||||||
entry := testutils.MakeFakeHistoryEntry("ls ~/")
|
entry := testutils.MakeFakeHistoryEntry("ls ~/")
|
||||||
encEntry, err := data.EncryptHistoryEntry("key", entry)
|
encEntry, err := data.EncryptHistoryEntry("key", entry)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -85,7 +85,7 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
||||||
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
|
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
|
||||||
|
|
||||||
// Query for device id 1
|
// Query for device id 1, no results returned
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
s.apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
@ -96,16 +96,9 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var retrievedEntries []*shared.EncHistoryEntry
|
var retrievedEntries []*shared.EncHistoryEntry
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
require.Equal(t, 1, len(retrievedEntries))
|
require.Equal(t, 0, len(retrievedEntries))
|
||||||
dbEntry := retrievedEntries[0]
|
|
||||||
require.Equal(t, devId1, dbEntry.DeviceId)
|
|
||||||
require.Equal(t, data.UserId("key"), dbEntry.UserId)
|
|
||||||
require.Equal(t, 0, dbEntry.ReadCount)
|
|
||||||
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, decEntry, entry)
|
|
||||||
|
|
||||||
// Same for device id 2
|
// Query for device id 2 and the entry is found
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
||||||
s.apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
@ -114,20 +107,12 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
if len(retrievedEntries) != 1 {
|
require.Len(t, retrievedEntries, 1)
|
||||||
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
|
dbEntry := retrievedEntries[0]
|
||||||
}
|
require.Equal(t, dbEntry.DeviceId, devId2)
|
||||||
dbEntry = retrievedEntries[0]
|
require.Equal(t, dbEntry.UserId, data.UserId("key"))
|
||||||
if dbEntry.DeviceId != devId2 {
|
require.Equal(t, 0, dbEntry.ReadCount)
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
|
||||||
}
|
|
||||||
if dbEntry.UserId != data.UserId("key") {
|
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
|
||||||
}
|
|
||||||
if dbEntry.ReadCount != 0 {
|
|
||||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
|
||||||
}
|
|
||||||
decEntry, err = data.DecryptHistoryEntry("key", *dbEntry)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, decEntry, entry)
|
require.Equal(t, decEntry, entry)
|
||||||
|
|
||||||
@ -140,9 +125,7 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
if len(retrievedEntries) != 2 {
|
require.Len(t, retrievedEntries, 2)
|
||||||
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
// Assert that we aren't leaking connections
|
||||||
assertNoLeakedConnections(t, DB)
|
assertNoLeakedConnections(t, DB)
|
||||||
@ -177,16 +160,10 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var dumpRequests []*shared.DumpRequest
|
var dumpRequests []*shared.DumpRequest
|
||||||
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
|
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
|
||||||
if len(dumpRequests) != 1 {
|
require.Len(t, dumpRequests, 1)
|
||||||
t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
|
|
||||||
}
|
|
||||||
dumpRequest := dumpRequests[0]
|
dumpRequest := dumpRequests[0]
|
||||||
if dumpRequest.RequestingDeviceId != devId2 {
|
require.Equal(t, devId2, dumpRequest.RequestingDeviceId)
|
||||||
t.Fatalf("unexpected device ID")
|
require.Equal(t, userId, dumpRequest.UserId)
|
||||||
}
|
|
||||||
if dumpRequest.UserId != userId {
|
|
||||||
t.Fatalf("unexpected user ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
// And one for otherUser
|
// And one for otherUser
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
@ -197,16 +174,10 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
dumpRequests = make([]*shared.DumpRequest, 0)
|
dumpRequests = make([]*shared.DumpRequest, 0)
|
||||||
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
|
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
|
||||||
if len(dumpRequests) != 1 {
|
require.Len(t, dumpRequests, 1)
|
||||||
t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
|
|
||||||
}
|
|
||||||
dumpRequest = dumpRequests[0]
|
dumpRequest = dumpRequests[0]
|
||||||
if dumpRequest.RequestingDeviceId != otherDev2 {
|
require.Equal(t, otherDev2, dumpRequest.RequestingDeviceId)
|
||||||
t.Fatalf("unexpected device ID")
|
require.Equal(t, otherUser, dumpRequest.UserId)
|
||||||
}
|
|
||||||
if dumpRequest.UserId != otherUser {
|
|
||||||
t.Fatalf("unexpected user ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
// And none if we query for a user ID that doesn't exit
|
// And none if we query for a user ID that doesn't exit
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
@ -270,16 +241,10 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
dumpRequests = make([]*shared.DumpRequest, 0)
|
dumpRequests = make([]*shared.DumpRequest, 0)
|
||||||
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
|
require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
|
||||||
if len(dumpRequests) != 1 {
|
require.Len(t, dumpRequests, 1)
|
||||||
t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
|
|
||||||
}
|
|
||||||
dumpRequest = dumpRequests[0]
|
dumpRequest = dumpRequests[0]
|
||||||
if dumpRequest.RequestingDeviceId != otherDev2 {
|
require.Equal(t, otherDev2, dumpRequest.RequestingDeviceId)
|
||||||
t.Fatalf("unexpected device ID")
|
require.Equal(t, otherUser, dumpRequest.UserId)
|
||||||
}
|
|
||||||
if dumpRequest.UserId != otherUser {
|
|
||||||
t.Fatalf("unexpected user ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
// And finally, query to ensure that the dumped entries are in the DB
|
// And finally, query to ensure that the dumped entries are in the DB
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
@ -291,19 +256,11 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var retrievedEntries []*shared.EncHistoryEntry
|
var retrievedEntries []*shared.EncHistoryEntry
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
if len(retrievedEntries) != 2 {
|
require.Len(t, retrievedEntries, 2)
|
||||||
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
|
|
||||||
}
|
|
||||||
for _, dbEntry := range retrievedEntries {
|
for _, dbEntry := range retrievedEntries {
|
||||||
if dbEntry.DeviceId != devId2 {
|
require.Equal(t, devId2, dbEntry.DeviceId)
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
require.Equal(t, userId, dbEntry.UserId)
|
||||||
}
|
require.Equal(t, 0, dbEntry.ReadCount)
|
||||||
if dbEntry.UserId != userId {
|
|
||||||
t.Fatalf("Response contains an incorrect user ID: %#v", *dbEntry)
|
|
||||||
}
|
|
||||||
if dbEntry.ReadCount != 0 {
|
|
||||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
|
||||||
}
|
|
||||||
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
|
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, assert.ObjectsAreEqual(decEntry, entry1Dec) || assert.ObjectsAreEqual(decEntry, entry2Dec))
|
require.True(t, assert.ObjectsAreEqual(decEntry, entry1Dec) || assert.ObjectsAreEqual(decEntry, entry2Dec))
|
||||||
@ -345,7 +302,6 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
s.apiSubmitHandler(w, submitReq)
|
s.apiSubmitHandler(w, submitReq)
|
||||||
require.Equal(t, 200, w.Result().StatusCode)
|
require.Equal(t, 200, w.Result().StatusCode)
|
||||||
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
||||||
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
|
|
||||||
|
|
||||||
// And another entry for user1
|
// And another entry for user1
|
||||||
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
||||||
@ -359,7 +315,6 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
s.apiSubmitHandler(w, submitReq)
|
s.apiSubmitHandler(w, submitReq)
|
||||||
require.Equal(t, 200, w.Result().StatusCode)
|
require.Equal(t, 200, w.Result().StatusCode)
|
||||||
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
||||||
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
|
|
||||||
|
|
||||||
// And an entry for user2 that has the same timestamp as the previous entry
|
// And an entry for user2 that has the same timestamp as the previous entry
|
||||||
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
||||||
@ -374,7 +329,6 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
s.apiSubmitHandler(w, submitReq)
|
s.apiSubmitHandler(w, submitReq)
|
||||||
require.Equal(t, 200, w.Result().StatusCode)
|
require.Equal(t, 200, w.Result().StatusCode)
|
||||||
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
require.Empty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
||||||
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
|
|
||||||
|
|
||||||
// Query for device id 1
|
// Query for device id 1
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
@ -386,19 +340,11 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var retrievedEntries []*shared.EncHistoryEntry
|
var retrievedEntries []*shared.EncHistoryEntry
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
if len(retrievedEntries) != 2 {
|
require.Len(t, retrievedEntries, 1)
|
||||||
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
|
|
||||||
}
|
|
||||||
for _, dbEntry := range retrievedEntries {
|
for _, dbEntry := range retrievedEntries {
|
||||||
if dbEntry.DeviceId != devId1 {
|
require.Equal(t, devId1, dbEntry.DeviceId)
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
require.Equal(t, data.UserId("dkey"), dbEntry.UserId)
|
||||||
}
|
require.Equal(t, 0, dbEntry.ReadCount)
|
||||||
if dbEntry.UserId != data.UserId("dkey") {
|
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
|
||||||
}
|
|
||||||
if dbEntry.ReadCount != 0 {
|
|
||||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
|
||||||
}
|
|
||||||
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
|
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, assert.ObjectsAreEqual(decEntry, entry1) || assert.ObjectsAreEqual(decEntry, entry2))
|
require.True(t, assert.ObjectsAreEqual(decEntry, entry1) || assert.ObjectsAreEqual(decEntry, entry2))
|
||||||
@ -428,19 +374,11 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
if len(retrievedEntries) != 1 {
|
require.Len(t, retrievedEntries, 1)
|
||||||
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
|
|
||||||
}
|
|
||||||
dbEntry := retrievedEntries[0]
|
dbEntry := retrievedEntries[0]
|
||||||
if dbEntry.DeviceId != devId1 {
|
require.Equal(t, devId1, dbEntry.DeviceId)
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
require.Equal(t, data.UserId("dkey"), dbEntry.UserId)
|
||||||
}
|
require.Equal(t, 1, dbEntry.ReadCount)
|
||||||
if dbEntry.UserId != data.UserId("dkey") {
|
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
|
||||||
}
|
|
||||||
if dbEntry.ReadCount != 1 {
|
|
||||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
|
||||||
}
|
|
||||||
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
|
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, decEntry, entry2)
|
require.Equal(t, decEntry, entry2)
|
||||||
@ -454,19 +392,11 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||||
if len(retrievedEntries) != 1 {
|
require.Len(t, retrievedEntries, 1)
|
||||||
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
|
|
||||||
}
|
|
||||||
dbEntry = retrievedEntries[0]
|
dbEntry = retrievedEntries[0]
|
||||||
if dbEntry.DeviceId != otherDev1 {
|
require.Equal(t, otherDev1, dbEntry.DeviceId)
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
require.Equal(t, data.UserId("dOtherkey"), dbEntry.UserId)
|
||||||
}
|
require.Equal(t, 0, dbEntry.ReadCount)
|
||||||
if dbEntry.UserId != data.UserId("dOtherkey") {
|
|
||||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
|
||||||
}
|
|
||||||
if dbEntry.ReadCount != 0 {
|
|
||||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
|
||||||
}
|
|
||||||
decEntry, err = data.DecryptHistoryEntry("dOtherkey", *dbEntry)
|
decEntry, err = data.DecryptHistoryEntry("dOtherkey", *dbEntry)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, decEntry, entry3)
|
require.Equal(t, decEntry, entry3)
|
||||||
@ -481,7 +411,6 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
s.apiSubmitHandler(w, submitReq)
|
s.apiSubmitHandler(w, submitReq)
|
||||||
require.Equal(t, 200, w.Result().StatusCode)
|
require.Equal(t, 200, w.Result().StatusCode)
|
||||||
require.NotEmpty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
require.NotEmpty(t, deserializeSubmitResponse(t, w).DeletionRequests)
|
||||||
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
|
|
||||||
|
|
||||||
// Query for deletion requests
|
// Query for deletion requests
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
@ -493,9 +422,7 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var deletionRequests []*shared.DeletionRequest
|
var deletionRequests []*shared.DeletionRequest
|
||||||
require.NoError(t, json.Unmarshal(respBody, &deletionRequests))
|
require.NoError(t, json.Unmarshal(respBody, &deletionRequests))
|
||||||
if len(deletionRequests) != 1 {
|
require.Len(t, deletionRequests, 1)
|
||||||
t.Fatalf("received %d deletion requests, expected only one", len(deletionRequests))
|
|
||||||
}
|
|
||||||
deletionRequest := deletionRequests[0]
|
deletionRequest := deletionRequests[0]
|
||||||
expected := shared.DeletionRequest{
|
expected := shared.DeletionRequest{
|
||||||
UserId: data.UserId("dkey"),
|
UserId: data.UserId("dkey"),
|
||||||
@ -518,16 +445,12 @@ func TestHealthcheck(t *testing.T) {
|
|||||||
s := NewServer(DB, TrackUsageData(true))
|
s := NewServer(DB, TrackUsageData(true))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
if w.Code != 200 {
|
require.Equal(t, 200, w.Code)
|
||||||
t.Fatalf("expected 200 resp code for healthCheckHandler")
|
|
||||||
}
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err := io.ReadAll(res.Body)
|
respBody, err := io.ReadAll(res.Body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if string(respBody) != "OK" {
|
require.Equal(t, "OK", string(respBody))
|
||||||
t.Fatalf("expected healthcheckHandler to return OK")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
// Assert that we aren't leaking connections
|
||||||
assertNoLeakedConnections(t, DB)
|
assertNoLeakedConnections(t, DB)
|
||||||
|
@ -2996,7 +2996,9 @@ func BenchmarkQuery(b *testing.B) {
|
|||||||
// Benchmarked code:
|
// Benchmarked code:
|
||||||
b.StartTimer()
|
b.StartTimer()
|
||||||
ctx := hctx.MakeContext()
|
ctx := hctx.MakeContext()
|
||||||
_, err := lib.Search(ctx, hctx.GetDb(ctx), "echo", 100)
|
err := lib.RetrieveAdditionalEntriesFromRemote(ctx, "tui")
|
||||||
|
require.NoError(b, err)
|
||||||
|
_, err = lib.Search(ctx, hctx.GetDb(ctx), "echo", 100)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
}
|
}
|
||||||
|
@ -481,7 +481,7 @@ func ApiGet(ctx context.Context, path string) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("failed to read response body from GET %s%s: %w", GetServerHostname(), path, err)
|
return nil, fmt.Errorf("failed to read response body from GET %s%s: %w", GetServerHostname(), path, err)
|
||||||
}
|
}
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
hctx.GetLogger().Infof("ApiGet(%#v): %s\n", path, duration.String())
|
hctx.GetLogger().Infof("ApiGet(%#v): %d bytes - %s\n", GetServerHostname()+path, len(respBody), duration.String())
|
||||||
return respBody, nil
|
return respBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -511,7 +511,7 @@ func ApiPost(ctx context.Context, path, contentType string, reqBody []byte) ([]b
|
|||||||
return nil, fmt.Errorf("failed to read response body from POST %s: %w", GetServerHostname()+path, err)
|
return nil, fmt.Errorf("failed to read response body from POST %s: %w", GetServerHostname()+path, err)
|
||||||
}
|
}
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
hctx.GetLogger().Infof("ApiPost(%#v): %s\n", GetServerHostname()+path, duration.String())
|
hctx.GetLogger().Infof("ApiPost(%#v): %d bytes - %s\n", GetServerHostname()+path, len(respBody), duration.String())
|
||||||
return respBody, nil
|
return respBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,6 +19,10 @@ type EncHistoryEntry struct {
|
|||||||
// Note that EncHistoryEntry.EncryptedId == HistoryEntry.Id (for entries created after pre-saving support)
|
// Note that EncHistoryEntry.EncryptedId == HistoryEntry.Id (for entries created after pre-saving support)
|
||||||
EncryptedId string `json:"encrypted_id"`
|
EncryptedId string `json:"encrypted_id"`
|
||||||
ReadCount int `json:"read_count"`
|
ReadCount int `json:"read_count"`
|
||||||
|
// Whether this encrypted history entry came from DeviceId. If IsFromSameDevice is true,
|
||||||
|
// then this won't be sent back by the query endpoint. We do still purposefully store
|
||||||
|
// these since they're useful for initializing new devices.
|
||||||
|
IsFromSameDevice bool `json:"is_from_same_device"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Represents a request to get all history entries from a given device. Used as part of bootstrapping
|
// Represents a request to get all history entries from a given device. Used as part of bootstrapping
|
||||||
|
Loading…
Reference in New Issue
Block a user