diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4c4ef6c3c..efe088b27 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -38,6 +39,7 @@ type GRPCServer struct { jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager + peerLocks sync.Map } // NewServer creates a new Management server @@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck @@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } + unlock() + unlock = nil + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w } func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { + unlock := s.acquirePeerLockByUID(ctx, peer.Key) + defer unlock() + + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } @@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return claims.UserId, nil } +func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) + + start := time.Now() + value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.Lock() + log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start)) + start = time.Now() + + unlock = func() { + mtx.Unlock() + log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start)) + } + + return unlock +} + // maps internal internalStatus.Error to gRPC status.Error func mapError(ctx context.Context, err error) error { if e, ok := internalStatus.FromError(err); ok {