Compare commits

...

10 Commits

Author SHA1 Message Date
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
311a15363a Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
12 changed files with 558 additions and 80 deletions

View File

@ -126,26 +126,6 @@ func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
return err
}
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration
@ -162,6 +142,11 @@ func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
return nil, err
}
if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
return &conf, ValidateDaemonConfiguration(&conf)
}

View File

@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error
// if something went wrong
GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne
// HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise.
HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone
Close() error
}
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection
conLoc sync.RWMutex
clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config
connFactory PeerConnectionFactory
}
@ -61,37 +62,25 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err
}
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool()
if !params.SkipCertVerification {
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool,
}
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{
sync.RWMutex{},
connections,
serverConfig,
clientConfig,
params.ConnFactory,
}
@ -150,6 +138,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists
}
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
err := m.clientConnections[endPoint].Close()
delete(m.clientConnections, endPoint)
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams()
params.CaCert = ""
params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true
_, err := NewConnectionManager(params)
if err != nil {
t.Fatal(`an error should not be thrown`)
if err == nil {
t.Fatalf(`an error should be thrown`)
}
}

View File

@ -2,8 +2,11 @@ package conn
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log"
@ -14,10 +17,8 @@ import (
// ConnectionServer manages gRPC server peer connections
type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes
server *grpc.Server
// the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes
@ -48,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
}
server := grpc.NewServer(
@ -61,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider
connServer := ConnectionServer{
serverConfig: serverConfig,
server: server,
ctrlProvider: ctrlProvider,
syncProvider: syncProvider,
@ -74,7 +91,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil
}
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint]

View File

@ -158,8 +158,8 @@ type TwoPhaseStoreMeshManager struct {
IfName string
Client *wgctrl.Client
LastClock uint64
conf *conf.WgConfiguration
daemonConf *conf.DaemonConfiguration
Conf *conf.WgConfiguration
DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode]
}
@ -204,7 +204,6 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot)
if err != nil {
@ -265,7 +264,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) {
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.KeepAliveTime) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
@ -411,6 +410,11 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
}
node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key)
m.store.Put(nodeId, node)
return nil
@ -510,5 +514,5 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
// GetConfiguration implements mesh.MeshProvider.
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
return m.Conf
}

442
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,442 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.DNS_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
Profile: false,
SyncRate: 2,
KeepAliveTime: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
BranchRate: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

View File

@ -18,8 +18,8 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
MeshId: params.MeshId,
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
daemonConf: params.DaemonConf,
Conf: params.Conf,
DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))

View File

@ -91,7 +91,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
return p.PublicKey.String() == pubKey.String()
})
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil {
return nil, err

View File

@ -54,9 +54,6 @@ func (s *SyncerImpl) Sync(meshId string) error {
s.manager.GetRouteManager().UpdateRoutes()
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
nodeNames := correspondingMesh.GetPeers()
if self != nil {
@ -68,13 +65,15 @@ func (s *SyncerImpl) Sync(meshId string) error {
var gossipNodes []string
// Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
gossipNodes = make([]string, 1)
gossipNodes[0] = neighbour
if len(neighbours) == 0 {
return nil
}
redundancyLength := min(len(neighbours), 3)
gossipNodes = neighbours[:redundancyLength]
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
@ -92,17 +91,17 @@ func (s *SyncerImpl) Sync(meshId string) error {
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", node)
continue
}
err := s.requester.SyncMesh(meshId, correspondingPeer)
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
if err != nil {
logging.Log.WriteInfof(err.Error())
}
}

View File

@ -1,6 +1,7 @@
package sync
import (
"github.com/tim-beatham/wgmesh/pkg/conn"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"google.golang.org/grpc/codes"
@ -15,11 +16,34 @@ type SyncErrorHandler interface {
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager
connManager conn.ConnectionManager
}
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if err != nil {
s.connManager.RemoveConnection(node.GetHostEndpoint())
}
return true
}
func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
if mesh == nil {
return true
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true
}
@ -29,13 +53,15 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) b
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.handleFailed(meshId, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(meshId, nodeId)
}
return false
}
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
}

View File

@ -151,6 +151,6 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
}
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
errorHdlr := NewSyncErrorHandler(s.MeshManager)
errorHdlr := NewSyncErrorHandler(s.MeshManager, s.ConnectionManager)
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
}