Use accountID retrieved from the sync call to acquire read lock sooner (#2369)

Use accountID retrieved from the sync call to acquire read lock sooner and avoiding extra DB calls.
- Use the account ID across sync calls
- Moved account read lock
- Renamed CancelPeerRoutines to OnPeerDisconnected
- Added race tests
This commit is contained in:
Maycon Santos 2024-08-01 16:21:43 +02:00 committed by GitHub
parent 02f3105e48
commit cbf9f2058e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 198 additions and 53 deletions

View File

@ -135,8 +135,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
CancelPeerRoutines(ctx context.Context, peerPubKey string) error
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@ -1857,22 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
// acquiring peer write lock here is ok since we only modify peer information that is supplied by the
// peer itself which can't be modified by API, and it only happens after an account read lock is acquired
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
}
return nil, nil, nil, err
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@ -1892,22 +1881,11 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
return peer, netMap, postureChecks, nil
}
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peerPubKey string) error {
// acquiring peer write lock here is ok since we only modify peer information that is supplied by the
// peer itself which can't be modified by API, and it only happens after an account read lock is acquired
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return status.Errorf(status.Unauthenticated, "peer not registered")
}
return err
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {

View File

@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
return mapError(ctx, err)
}
@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
return err
}
@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return srv.Context().Err()
}
}
@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, peer)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(ctx, peer.Key)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}

View File

@ -2,6 +2,7 @@ package server
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
@ -16,6 +17,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util"
@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) {
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
mgmtServer, mgmtAddr, err := startManagement(t, &Config{
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@ -399,32 +401,35 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) {
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
return nil, nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
return nil, nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, "", err
return nil, nil, "", err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
if err != nil {
return nil, "", err
return nil, nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
@ -434,7 +439,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
}
}()
return s, lis.Addr().String(), nil
return s, accountManager, lis.Addr().String(), nil
}
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
@ -454,3 +459,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
return mgmtProto.NewManagementServiceClient(conn), conn, nil
}
func Test_SyncStatusRace(t *testing.T) {
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
t.Skip("Skipping on CI and Postgres store")
}
for i := 0; i < 500; i++ {
t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
testSyncStatusRace(t)
})
}
}
func testSyncStatusRace(t *testing.T) {
t.Helper()
dir := t.TempDir()
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
if err != nil {
t.Fatal(err)
}
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
}},
TURNConfig: &TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Proto: "udp",
URI: "turn:stun.wiretrustee.com:3468",
}},
},
Signal: &Host{
Proto: "http",
URI: "signal.wiretrustee.com:10000",
},
Datadir: dir,
HttpConfig: nil,
})
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
client, clientConn, err := createRawClient(mgmtAddr)
if err != nil {
t.Fatal(err)
return
}
defer clientConn.Close()
// there are two peers already in the store, add two more
peers, err := registerPeers(2, client)
if err != nil {
t.Fatal(err)
return
}
serverKey, err := getServerKey(client)
if err != nil {
t.Fatal(err)
return
}
concurrentPeerKey2 := peers[1]
t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
if err != nil {
t.Fatal(err)
return
}
ctx2, cancelFunc2 := context.WithCancel(context.Background())
//client.
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
WgPubKey: concurrentPeerKey2.PublicKey().String(),
Body: message2,
})
if err != nil {
t.Fatal(err)
return
}
resp2 := &mgmtProto.EncryptedMessage{}
err = sync2.RecvMsg(resp2)
if err != nil {
t.Fatal(err)
return
}
peerWithInvalidStatus := peers[0]
t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
if err != nil {
t.Fatal(err)
return
}
ctx, cancelFunc := context.WithCancel(context.Background())
//client.
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
})
if err != nil {
t.Fatal(err)
return
}
// take the first registered peer as a base for the test. Total four.
resp := &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(resp)
if err != nil {
t.Fatal(err)
return
}
cancelFunc2()
time.Sleep(1 * time.Millisecond)
cancelFunc()
time.Sleep(10 * time.Millisecond)
ctx, cancelFunc = context.WithCancel(context.Background())
defer cancelFunc()
sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
})
if err != nil {
t.Fatal(err)
return
}
resp = &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(resp)
if err != nil {
t.Fatal(err)
return
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
}
if !peer.Status.Connected {
t.Fatal("Peer should be connected")
}
}

View File

@ -31,7 +31,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
@ -105,14 +105,14 @@ type MockAccountManager struct {
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peerPubKey string) error {
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
}