Fix double-syncing error where devices receive entries from themselves

This commit is contained in:
David Dworken 2024-04-14 18:08:41 -07:00
parent 3b7f943d55
commit 3b6467956c
No known key found for this signature in database
5 changed files with 18 additions and 11 deletions

View File

@ -31,7 +31,7 @@ func (db *DB) AllHistoryEntriesForUser(ctx context.Context, userID string) ([]*s
func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) { func (db *DB) HistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
var historyEntries []*shared.EncHistoryEntry var historyEntries []*shared.EncHistoryEntry
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ?", deviceID, limit).Find(&historyEntries) tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ? AND NOT is_from_same_device", deviceID, limit).Find(&historyEntries)
if tx.Error != nil { if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error) return nil, fmt.Errorf("tx.Error: %w", tx.Error)
@ -52,12 +52,13 @@ func (db *DB) AddHistoryEntries(ctx context.Context, entries ...*shared.EncHisto
}) })
} }
func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, devices []*Device, entries []*shared.EncHistoryEntry) error { func (db *DB) AddHistoryEntriesForAllDevices(ctx context.Context, sourceDeviceId string, devices []*Device, entries []*shared.EncHistoryEntry) error {
chunkSize := 1000 chunkSize := 1000
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, device := range devices { for _, device := range devices {
for _, entry := range entries { for _, entry := range entries {
entry.DeviceId = device.DeviceId entry.DeviceId = device.DeviceId
entry.IsFromSameDevice = sourceDeviceId == device.DeviceId
} }
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error // Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
for _, entriesChunk := range shared.Chunks(entries, chunkSize) { for _, entriesChunk := range shared.Chunks(entries, chunkSize) {

View File

@ -39,7 +39,8 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
} }
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
err = s.db.AddHistoryEntriesForAllDevices(r.Context(), devices, entries) sourceDeviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
err = s.db.AddHistoryEntriesForAllDevices(r.Context(), sourceDeviceId, devices, entries)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
} }
@ -49,21 +50,20 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
resp := shared.SubmitResponse{} resp := shared.SubmitResponse{}
deviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment) if sourceDeviceId != "" {
if deviceId != "" {
hv, err := shared.ParseVersionString(version) hv, err := shared.ParseVersionString(version)
if err != nil || hv.GreaterThan(shared.ParsedVersion{MinorVersion: 0, MajorVersion: 221}) { if err != nil || hv.GreaterThan(shared.ParsedVersion{MinorVersion: 0, MajorVersion: 221}) {
// Note that if we fail to parse the version string, we do return dump and deletion requests. This is necessary // Note that if we fail to parse the version string, we do return dump and deletion requests. This is necessary
// since tests run with v0.Unknown which obviously fails to parse. // since tests run with v0.Unknown which obviously fails to parse.
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, sourceDeviceId)
checkGormError(err) checkGormError(err)
resp.DumpRequests = dumpRequests resp.DumpRequests = dumpRequests
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, sourceDeviceId)
checkGormError(err) checkGormError(err)
resp.DeletionRequests = deletionRequests resp.DeletionRequests = deletionRequests
checkGormError(s.db.DeletionRequestInc(r.Context(), userId, deviceId)) checkGormError(s.db.DeletionRequestInc(r.Context(), userId, sourceDeviceId))
} }
} }

View File

@ -2996,7 +2996,9 @@ func BenchmarkQuery(b *testing.B) {
// Benchmarked code: // Benchmarked code:
b.StartTimer() b.StartTimer()
ctx := hctx.MakeContext() ctx := hctx.MakeContext()
_, err := lib.Search(ctx, hctx.GetDb(ctx), "echo", 100) err := lib.RetrieveAdditionalEntriesFromRemote(ctx, "tui")
require.NoError(b, err)
_, err = lib.Search(ctx, hctx.GetDb(ctx), "echo", 100)
require.NoError(b, err) require.NoError(b, err)
b.StopTimer() b.StopTimer()
} }

View File

@ -481,7 +481,7 @@ func ApiGet(ctx context.Context, path string) ([]byte, error) {
return nil, fmt.Errorf("failed to read response body from GET %s%s: %w", GetServerHostname(), path, err) return nil, fmt.Errorf("failed to read response body from GET %s%s: %w", GetServerHostname(), path, err)
} }
duration := time.Since(start) duration := time.Since(start)
hctx.GetLogger().Infof("ApiGet(%#v): %s\n", path, duration.String()) hctx.GetLogger().Infof("ApiGet(%#v): %d bytes - %s\n", GetServerHostname()+path, len(respBody), duration.String())
return respBody, nil return respBody, nil
} }
@ -511,7 +511,7 @@ func ApiPost(ctx context.Context, path, contentType string, reqBody []byte) ([]b
return nil, fmt.Errorf("failed to read response body from POST %s: %w", GetServerHostname()+path, err) return nil, fmt.Errorf("failed to read response body from POST %s: %w", GetServerHostname()+path, err)
} }
duration := time.Since(start) duration := time.Since(start)
hctx.GetLogger().Infof("ApiPost(%#v): %s\n", GetServerHostname()+path, duration.String()) hctx.GetLogger().Infof("ApiPost(%#v): %d bytes - %s\n", GetServerHostname()+path, len(respBody), duration.String())
return respBody, nil return respBody, nil
} }

View File

@ -19,6 +19,10 @@ type EncHistoryEntry struct {
// Note that EncHistoryEntry.EncryptedId == HistoryEntry.Id (for entries created after pre-saving support) // Note that EncHistoryEntry.EncryptedId == HistoryEntry.Id (for entries created after pre-saving support)
EncryptedId string `json:"encrypted_id"` EncryptedId string `json:"encrypted_id"`
ReadCount int `json:"read_count"` ReadCount int `json:"read_count"`
// Whether this encrypted history entry came from DeviceId. If IsFromSameDevice is true,
// then this won't be sent back by the query endpoint. We do still purposefully store
// these since they're useful for initializing new devices.
IsFromSameDevice bool `json:"is_from_same_device"`
} }
// Represents a request to get all history entries from a given device. Used as part of bootstrapping // Represents a request to get all history entries from a given device. Used as part of bootstrapping