diff --git a/management/server/account.go b/management/server/account.go index 4648c00cd..5d3ee6dc1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -135,8 +135,8 @@ type AccountManager interface { UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CancelPeerRoutines(ctx context.Context, peerPubKey string) error + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) @@ -1857,22 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C } } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - // acquiring peer write lock here is ok since we only modify peer information that is supplied by the - // peer itself which can't be modified by API, and it only happens after an account read lock is acquired - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") - } - return nil, nil, nil, err - } - +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1892,22 +1881,11 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey return peer, netMap, postureChecks, nil } -func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peerPubKey string) error { - // acquiring peer write lock here is ok since we only modify peer information that is supplied by the - // peer itself which can't be modified by API, and it only happens after an account read lock is acquired - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return status.Errorf(status.Unauthenticated, "peer not registered") - } - return err - } - +func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4a12a5c3e..f71a45d99 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) if err != nil { return mapError(ctx, err) } @@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } - return s.handleUpdates(ctx, peerKey, peer, updates, srv) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { for { select { // condition when there are some updates @@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return nil } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { return err } @@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return srv.Context().Err() } } @@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ @@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer * Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { +func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.CancelPeerRoutines(ctx, peer.Key) + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index e1f7787f2..2c9d43948 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "net" "os" "path/filepath" @@ -16,6 +17,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" @@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) { defer func() { os.Remove(filepath.Join(dir, "store.json")) //nolint }() - mgmtServer, mgmtAddr, err := startManagement(t, &Config{ + mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -399,32 +401,35 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) { +func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, "", err + return nil, nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { - return nil, "", err + return nil, nil, "", err } t.Cleanup(cleanUp) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", + + ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck + + accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { - return nil, "", err + return nil, nil, "", err } turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) ephemeralMgr := NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) if err != nil { - return nil, "", err + return nil, nil, "", err } mgmtProto.RegisterManagementServiceServer(s, mgmtServer) @@ -434,7 +439,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) } }() - return s, lis.Addr().String(), nil + return s, accountManager, lis.Addr().String(), nil } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { @@ -454,3 +459,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn, nil } +func Test_SyncStatusRace(t *testing.T) { + if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { + t.Skip("Skipping on CI and Postgres store") + } + for i := 0; i < 500; i++ { + t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) { + testSyncStatusRace(t) + }) + } +} +func testSyncStatusRace(t *testing.T) { + t.Helper() + dir := t.TempDir() + err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) + if err != nil { + t.Fatal(err) + } + defer func() { + os.Remove(filepath.Join(dir, "store.json")) //nolint + }() + + mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{ + Stuns: []*Host{{ + Proto: "udp", + URI: "stun:stun.wiretrustee.com:3468", + }}, + TURNConfig: &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{}, + Secret: "whatever", + Turns: []*Host{{ + Proto: "udp", + URI: "turn:stun.wiretrustee.com:3468", + }}, + }, + Signal: &Host{ + Proto: "http", + URI: "signal.wiretrustee.com:10000", + }, + Datadir: dir, + HttpConfig: nil, + }) + if err != nil { + t.Fatal(err) + return + } + defer mgmtServer.GracefulStop() + + client, clientConn, err := createRawClient(mgmtAddr) + if err != nil { + t.Fatal(err) + return + } + + defer clientConn.Close() + + // there are two peers already in the store, add two more + peers, err := registerPeers(2, client) + if err != nil { + t.Fatal(err) + return + } + + serverKey, err := getServerKey(client) + if err != nil { + t.Fatal(err) + return + } + + concurrentPeerKey2 := peers[1] + t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String()) + + syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2) + if err != nil { + t.Fatal(err) + return + } + + ctx2, cancelFunc2 := context.WithCancel(context.Background()) + + //client. + sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{ + WgPubKey: concurrentPeerKey2.PublicKey().String(), + Body: message2, + }) + if err != nil { + t.Fatal(err) + return + } + + resp2 := &mgmtProto.EncryptedMessage{} + err = sync2.RecvMsg(resp2) + if err != nil { + t.Fatal(err) + return + } + + peerWithInvalidStatus := peers[0] + t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String()) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq) + if err != nil { + t.Fatal(err) + return + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + + //client. + sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{ + WgPubKey: peerWithInvalidStatus.PublicKey().String(), + Body: message, + }) + if err != nil { + t.Fatal(err) + return + } + + // take the first registered peer as a base for the test. Total four. + + resp := &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(resp) + if err != nil { + t.Fatal(err) + return + } + + cancelFunc2() + time.Sleep(1 * time.Millisecond) + cancelFunc() + time.Sleep(10 * time.Millisecond) + + ctx, cancelFunc = context.WithCancel(context.Background()) + defer cancelFunc() + sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{ + WgPubKey: peerWithInvalidStatus.PublicKey().String(), + Body: message, + }) + if err != nil { + t.Fatal(err) + return + } + + resp = &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(resp) + if err != nil { + t.Fatal(err) + return + } + + time.Sleep(10 * time.Millisecond) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) + if err != nil { + t.Fatal(err) + return + } + if !peer.Status.Connected { + t.Fatal("Peer should be connected") + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 1adf9a2d6..a66bdee2b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,7 @@ type MockAccountManager struct { ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) @@ -105,14 +105,14 @@ type MockAccountManager struct { GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peerPubKey string) error { +func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { // TODO implement me panic("implement me") }