From 1f48fdf6cadca8576bca578dca829e9672cb98ce Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 26 Jul 2024 07:49:05 +0200 Subject: [PATCH] Add SavePeer method to prevent a possible account inconsistency (#2296) SyncPeer was storing the account with a simple read lock This change introduces the SavePeer method to the store to be used in these cases --- management/server/file_store.go | 20 ++++++++++++ management/server/peer.go | 5 +-- management/server/sql_store.go | 48 +++++++++++++++++++++++----- management/server/sql_store_test.go | 49 +++++++++++++++++++++++++++++ management/server/store.go | 4 ++- 5 files changed, 116 insertions(+), 10 deletions(-) diff --git a/management/server/file_store.go b/management/server/file_store.go index c649602e2..9a1462832 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { return s.persist(ctx, s.storeFile) } +// SavePeer saves the peer in the account +func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + newPeer := peer.Copy() + + account.Peers[peer.ID] = newPeer + + s.PeerKeyID2AccountID[peer.Key] = accountID + s.PeerID2AccountID[peer.ID] = accountID + + return nil +} + // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. // PeerStatus will be saved eventually when some other changes occur. func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { diff --git a/management/server/peer.go b/management/server/peer.go index b8605fbb7..ec8b773b0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -7,10 +7,11 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/management/server/posture" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -539,7 +540,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac peer, updated := updatePeerMeta(peer, sync.Meta, account) if updated { - err = am.Store.SaveAccount(ctx, account) + err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { return nil, nil, nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 37cc10d8b..7648538c3 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -31,8 +31,10 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -271,6 +273,38 @@ func (s *SqlStore) GetInstallationID() string { return installation.InstallationIDValue } +func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { + // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. + peerCopy := peer.Copy() + peerCopy.AccountID = accountID + + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // check if peer exists before saving + var peerID string + result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) + if result.Error != nil { + return result.Error + } + + if peerID == "" { + return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID) + } + + result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy) + if result.Error != nil { + return result.Error + } + + return nil + }) + + if err != nil { + return err + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -281,14 +315,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe } result := s.db.Model(&nbpeer.Peer{}). Select(fieldsToUpdate). - Where("account_id = ? AND id = ?", accountID, peerID). + Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "peer %s not found", peerID) + return status.Errorf(status.NotFound, peerNotFoundFMT, peerID) } return nil @@ -302,7 +336,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P peerCopy.Location = peerWithLocation.Location result := s.db.Model(&nbpeer.Peer{}). - Where("account_id = ? and id = ?", accountID, peerWithLocation.ID). + Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Updates(peerCopy) if result.Error != nil { @@ -310,7 +344,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) + return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID) } return nil @@ -644,7 +678,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { var user User - result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) + result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "user %s not found", userID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 7f48810d7..ce4ee531a 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } +func TestSqlite_SavePeer(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") + require.NoError(t, err) + + // save status of non-existing peer + peer := &nbpeer.Peer{ + Key: "peerkey", + ID: "testpeer", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + ctx := context.Background() + err = store.SavePeer(ctx, account.Id, peer) + assert.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + + // save new status of existing peer + account.Peers[peer.ID] = peer + + err = store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + updatedPeer := peer.Copy() + updatedPeer.Status.Connected = false + updatedPeer.Meta.Hostname = "updatedpeer" + + err = store.SavePeer(ctx, account.Id, updatedPeer) + require.NoError(t, err) + + account, err = store.GetAccount(context.Background(), account.Id) + require.NoError(t, err) + + actual := account.Peers[peer.ID] + assert.Equal(t, updatedPeer.Status, actual.Status) + assert.Equal(t, updatedPeer.Meta, actual.Meta) +} + func TestSqlite_SavePeerStatus(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") @@ -414,6 +462,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { actual = account.Peers["testpeer"].Status assert.Equal(t, newStatus, *actual) } + func TestSqlite_SavePeerLocation(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") diff --git a/management/server/store.go b/management/server/store.go index 3ba73e8c7..15a419c78 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,10 +12,11 @@ import ( "strings" "time" - nbgroup "github.com/netbirdio/netbird/management/server/group" log "github.com/sirupsen/logrus" "gorm.io/gorm" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -54,6 +55,7 @@ type Store interface { AcquireAccountReadLock(ctx context.Context, accountID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error