mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-28 02:53:20 +01:00
Use accountID retrieved from the sync call to acquire read lock sooner (#2369)
Use accountID retrieved from the sync call to acquire read lock sooner and avoiding extra DB calls. - Use the account ID across sync calls - Moved account read lock - Renamed CancelPeerRoutines to OnPeerDisconnected - Added race tests
This commit is contained in:
parent
02f3105e48
commit
cbf9f2058e
@ -135,8 +135,8 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
CancelPeerRoutines(ctx context.Context, peerPubKey string) error
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
@ -1857,22 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
// acquiring peer write lock here is ok since we only modify peer information that is supplied by the
|
||||
// peer itself which can't be modified by API, and it only happens after an account read lock is acquired
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@ -1892,22 +1881,11 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
|
||||
return peer, netMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peerPubKey string) error {
|
||||
// acquiring peer write lock here is ok since we only modify peer information that is supplied by the
|
||||
// peer itself which can't be modified by API, and it only happens after an account read lock is acquired
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
|
@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.CancelPeerRoutines(ctx, peer.Key)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -16,6 +17,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, &Config{
|
||||
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@ -399,32 +401,35 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
|
||||
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
@ -434,7 +439,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
return s, accountManager, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
|
||||
@ -454,3 +459,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
|
||||
|
||||
return mgmtProto.NewManagementServiceClient(conn), conn, nil
|
||||
}
|
||||
func Test_SyncStatusRace(t *testing.T) {
|
||||
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
|
||||
t.Skip("Skipping on CI and Postgres store")
|
||||
}
|
||||
for i := 0; i < 500; i++ {
|
||||
t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
|
||||
testSyncStatusRace(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
func testSyncStatusRace(t *testing.T) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
|
||||
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Proto: "http",
|
||||
URI: "signal.wiretrustee.com:10000",
|
||||
},
|
||||
Datadir: dir,
|
||||
HttpConfig: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, clientConn, err := createRawClient(mgmtAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
|
||||
// there are two peers already in the store, add two more
|
||||
peers, err := registerPeers(2, client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
serverKey, err := getServerKey(client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
concurrentPeerKey2 := peers[1]
|
||||
t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
|
||||
|
||||
syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx2, cancelFunc2 := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: concurrentPeerKey2.PublicKey().String(),
|
||||
Body: message2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp2 := &mgmtProto.EncryptedMessage{}
|
||||
err = sync2.RecvMsg(resp2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peerWithInvalidStatus := peers[0]
|
||||
t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
|
||||
|
||||
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
// take the first registered peer as a base for the test. Total four.
|
||||
|
||||
resp := &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
cancelFunc2()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
cancelFunc()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ctx, cancelFunc = context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp = &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if !peer.Status.Connected {
|
||||
t.Fatal("Peer should be connected")
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ type MockAccountManager struct {
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
|
||||
@ -105,14 +105,14 @@ type MockAccountManager struct {
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peerPubKey string) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user