mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 03:23:56 +01:00
Add SavePeer method to prevent a possible account inconsistency (#2296)
SyncPeer was storing the account with a simple read lock This change introduces the SavePeer method to the store to be used in these cases
This commit is contained in:
parent
45fd1e9c21
commit
1f48fdf6ca
@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
|
|||||||
return s.persist(ctx, s.storeFile)
|
return s.persist(ctx, s.storeFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SavePeer saves the peer in the account
|
||||||
|
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.getAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newPeer := peer.Copy()
|
||||||
|
|
||||||
|
account.Peers[peer.ID] = newPeer
|
||||||
|
|
||||||
|
s.PeerKeyID2AccountID[peer.Key] = accountID
|
||||||
|
s.PeerID2AccountID[peer.ID] = accountID
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
|
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
|
||||||
// PeerStatus will be saved eventually when some other changes occur.
|
// PeerStatus will be saved eventually when some other changes occur.
|
||||||
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||||
|
@ -7,10 +7,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@ -539,7 +540,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
|
|
||||||
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
||||||
if updated {
|
if updated {
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
storeSqliteFileName = "store.db"
|
storeSqliteFileName = "store.db"
|
||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||||
@ -271,6 +273,38 @@ func (s *SqlStore) GetInstallationID() string {
|
|||||||
return installation.InstallationIDValue
|
return installation.InstallationIDValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||||
|
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
|
||||||
|
peerCopy := peer.Copy()
|
||||||
|
peerCopy.AccountID = accountID
|
||||||
|
|
||||||
|
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
// check if peer exists before saving
|
||||||
|
var peerID string
|
||||||
|
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerID == "" {
|
||||||
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||||
var peerCopy nbpeer.Peer
|
var peerCopy nbpeer.Peer
|
||||||
peerCopy.Status = &peerStatus
|
peerCopy.Status = &peerStatus
|
||||||
@ -281,14 +315,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
|||||||
}
|
}
|
||||||
result := s.db.Model(&nbpeer.Peer{}).
|
result := s.db.Model(&nbpeer.Peer{}).
|
||||||
Select(fieldsToUpdate).
|
Select(fieldsToUpdate).
|
||||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||||
Updates(&peerCopy)
|
Updates(&peerCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -302,7 +336,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
|||||||
peerCopy.Location = peerWithLocation.Location
|
peerCopy.Location = peerWithLocation.Location
|
||||||
|
|
||||||
result := s.db.Model(&nbpeer.Peer{}).
|
result := s.db.Model(&nbpeer.Peer{}).
|
||||||
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
|
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
|
||||||
Updates(peerCopy)
|
Updates(peerCopy)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@ -310,7 +344,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
|
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -644,7 +678,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S
|
|||||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||||
var user User
|
var user User
|
||||||
|
|
||||||
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
|
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||||
|
@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) {
|
|||||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlite_SavePeer(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
store := newSqliteStoreFromFile(t, "testdata/store.json")
|
||||||
|
|
||||||
|
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// save status of non-existing peer
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
Key: "peerkey",
|
||||||
|
ID: "testpeer",
|
||||||
|
SetupKey: "peerkeysetupkey",
|
||||||
|
IP: net.IP{127, 0, 0, 1},
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
|
||||||
|
Name: "peer name",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
err = store.SavePeer(ctx, account.Id, peer)
|
||||||
|
assert.Error(t, err)
|
||||||
|
parsedErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||||
|
|
||||||
|
// save new status of existing peer
|
||||||
|
account.Peers[peer.ID] = peer
|
||||||
|
|
||||||
|
err = store.SaveAccount(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updatedPeer := peer.Copy()
|
||||||
|
updatedPeer.Status.Connected = false
|
||||||
|
updatedPeer.Meta.Hostname = "updatedpeer"
|
||||||
|
|
||||||
|
err = store.SavePeer(ctx, account.Id, updatedPeer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
account, err = store.GetAccount(context.Background(), account.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
actual := account.Peers[peer.ID]
|
||||||
|
assert.Equal(t, updatedPeer.Status, actual.Status)
|
||||||
|
assert.Equal(t, updatedPeer.Meta, actual.Meta)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqlite_SavePeerStatus(t *testing.T) {
|
func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
@ -414,6 +462,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
|||||||
actual = account.Peers["testpeer"].Status
|
actual = account.Peers["testpeer"].Status
|
||||||
assert.Equal(t, newStatus, *actual)
|
assert.Equal(t, newStatus, *actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlite_SavePeerLocation(t *testing.T) {
|
func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
@ -12,10 +12,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@ -54,6 +55,7 @@ type Store interface {
|
|||||||
AcquireAccountReadLock(ctx context.Context, accountID string) func()
|
AcquireAccountReadLock(ctx context.Context, accountID string) func()
|
||||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||||
AcquireGlobalLock(ctx context.Context) func()
|
AcquireGlobalLock(ctx context.Context) func()
|
||||||
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
||||||
|
Loading…
Reference in New Issue
Block a user