mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 03:23:56 +01:00
[management] Add peer lock to grpc server (#2859)
* add peer lock to grpc server * remove sleep and put db update first * don't export lock method
This commit is contained in:
parent
669904cd06
commit
67ce14eaea
@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pb "github.com/golang/protobuf/proto" // nolint
|
pb "github.com/golang/protobuf/proto" // nolint
|
||||||
@ -38,6 +39,7 @@ type GRPCServer struct {
|
|||||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager *EphemeralManager
|
ephemeralManager *EphemeralManager
|
||||||
|
peerLocks sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||||
|
|
||||||
|
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||||
|
defer func() {
|
||||||
|
if unlock != nil {
|
||||||
|
unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
@ -190,6 +199,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,9 +257,12 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||||
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||||
s.secretsManager.CancelRefresh(peer.ID)
|
s.secretsManager.CancelRefresh(peer.ID)
|
||||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
|
||||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,6 +289,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
|||||||
return claims.UserId, nil
|
return claims.UserId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||||
|
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||||
|
mtx := value.(*sync.RWMutex)
|
||||||
|
mtx.Lock()
|
||||||
|
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||||
|
start = time.Now()
|
||||||
|
|
||||||
|
unlock = func() {
|
||||||
|
mtx.Unlock()
|
||||||
|
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
return unlock
|
||||||
|
}
|
||||||
|
|
||||||
// maps internal internalStatus.Error to gRPC status.Error
|
// maps internal internalStatus.Error to gRPC status.Error
|
||||||
func mapError(ctx context.Context, err error) error {
|
func mapError(ctx context.Context, err error) error {
|
||||||
if e, ok := internalStatus.FromError(err); ok {
|
if e, ok := internalStatus.FromError(err); ok {
|
||||||
|
Loading…
Reference in New Issue
Block a user