Add SavePeer method to prevent a possible account inconsistency (#2296)

SyncPeer was storing the account with a simple read lock

This change introduces the SavePeer method to the store to be used in these cases
This commit is contained in:
Maycon Santos 2024-07-26 07:49:05 +02:00 committed by GitHub
parent 45fd1e9c21
commit 1f48fdf6ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 116 additions and 10 deletions

View File

@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
return s.persist(ctx, s.storeFile) return s.persist(ctx, s.storeFile)
} }
// SavePeer saves the peer in the account
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
newPeer := peer.Copy()
account.Peers[peer.ID] = newPeer
s.PeerKeyID2AccountID[peer.Key] = accountID
s.PeerID2AccountID[peer.ID] = accountID
return nil
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur. // PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {

View File

@ -7,10 +7,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -539,7 +540,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
peer, updated := updatePeerMeta(peer, sync.Meta, account) peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated { if updated {
err = am.Store.SaveAccount(ctx, account) err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }

View File

@ -31,8 +31,10 @@ import (
) )
const ( const (
storeSqliteFileName = "store.db" storeSqliteFileName = "store.db"
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
) )
// SqlStore represents an account storage backed by a Sql DB persisted to disk // SqlStore represents an account storage backed by a Sql DB persisted to disk
@ -271,6 +273,38 @@ func (s *SqlStore) GetInstallationID() string {
return installation.InstallationIDValue return installation.InstallationIDValue
} }
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
if result.Error != nil {
return result.Error
}
if peerID == "" {
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
}
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
if result.Error != nil {
return result.Error
}
return nil
})
if err != nil {
return err
}
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
@ -281,14 +315,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
} }
result := s.db.Model(&nbpeer.Peer{}). result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate). Select(fieldsToUpdate).
Where("account_id = ? AND id = ?", accountID, peerID). Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy) Updates(&peerCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerID) return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
} }
return nil return nil
@ -302,7 +336,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
peerCopy.Location = peerWithLocation.Location peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}). result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy) Updates(peerCopy)
if result.Error != nil { if result.Error != nil {
@ -310,7 +344,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
} }
return nil return nil
@ -644,7 +678,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User var user User
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID) return status.Errorf(status.NotFound, "user %s not found", userID)

View File

@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
} }
func TestSqlite_SavePeer(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/store.json")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
// save status of non-existing peer
peer := &nbpeer.Peer{
Key: "peerkey",
ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
ctx := context.Background()
err = store.SavePeer(ctx, account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
// save new status of existing peer
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
updatedPeer := peer.Copy()
updatedPeer.Status.Connected = false
updatedPeer.Meta.Hostname = "updatedpeer"
err = store.SavePeer(ctx, account.Id, updatedPeer)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID]
assert.Equal(t, updatedPeer.Status, actual.Status)
assert.Equal(t, updatedPeer.Meta, actual.Meta)
}
func TestSqlite_SavePeerStatus(t *testing.T) { func TestSqlite_SavePeerStatus(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
@ -414,6 +462,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
actual = account.Peers["testpeer"].Status actual = account.Peers["testpeer"].Status
assert.Equal(t, newStatus, *actual) assert.Equal(t, newStatus, *actual)
} }
func TestSqlite_SavePeerLocation(t *testing.T) { func TestSqlite_SavePeerLocation(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@ -12,10 +12,11 @@ import (
"strings" "strings"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -54,6 +55,7 @@ type Store interface {
AcquireAccountReadLock(ctx context.Context, accountID string) func() AcquireAccountReadLock(ctx context.Context, accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func() AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error