diff --git a/pkg/conn/connectionmanager.go b/pkg/conn/connectionmanager.go index 2928ec8..f1c774a 100644 --- a/pkg/conn/connectionmanager.go +++ b/pkg/conn/connectionmanager.go @@ -22,6 +22,8 @@ type ConnectionManager interface { // HasConnections returns true if a peer has already registered at the given // endpoint or false otherwise. HasConnection(endPoint string) bool + // Removes a connection if it exists + RemoveConnection(endPoint string) error // Goes through all the connections and closes eachone Close() error } @@ -150,6 +152,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool { return exists } +// RemoveConnection removes the given connection if it exists +func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error { + m.conLoc.Lock() + err := m.clientConnections[endPoint].Close() + + delete(m.clientConnections, endPoint) + m.conLoc.Unlock() + return err +} func (m *ConnectionManagerImpl) Close() error { for _, conn := range m.clientConnections { if err := conn.Close(); err != nil { diff --git a/pkg/conn/stub.go b/pkg/conn/stub.go index 18c53f2..203bb1b 100644 --- a/pkg/conn/stub.go +++ b/pkg/conn/stub.go @@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection, return mock, nil } +func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error { + delete(s.Endpoints, endPoint) + return nil +} + func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) { endpoint, ok := s.Endpoints[endPoint] diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 99fcbab..9cefe38 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -96,11 +96,6 @@ func (s *SyncerImpl) Sync(meshId string) error { if err == nil || err == io.EOF { succeeded = true - } else if self.GetType() == conf.PEER_ROLE { - // If the synchronisation operation has failed them mark a gravestone - // preventing the peer from being re-contacted until it has updated - // itself - s.manager.GetMesh(meshId).Mark(node) } if err != nil { diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go index b10c358..412113a 100644 --- a/pkg/sync/syncererror.go +++ b/pkg/sync/syncererror.go @@ -1,6 +1,7 @@ package sync import ( + "github.com/tim-beatham/wgmesh/pkg/conn" logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/mesh" "google.golang.org/grpc/codes" @@ -15,6 +16,7 @@ type SyncErrorHandler interface { // SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler type SyncErrorHandlerImpl struct { meshManager mesh.MeshManager + connManager conn.ConnectionManager } func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { @@ -23,14 +25,33 @@ func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { return true } +func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(meshId string, nodeId string) bool { + mesh := s.meshManager.GetMesh(nodeId) + + if mesh == nil { + return true + } + + node, err := mesh.GetNode(nodeId) + + if err != nil { + return false + } + + s.connManager.RemoveConnection(node.GetHostEndpoint()) + return true +} + func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool { errStatus, _ := status.FromError(err) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) switch errStatus.Code() { - case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: + case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound: return s.handleFailed(meshId, nodeId) + case codes.DeadlineExceeded: + return s.handleDeadlineExceeded(meshId, nodeId) } return false