diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go
index 343b8a9..e6407b6 100644
--- a/pkg/crdt/datastore.go
+++ b/pkg/crdt/datastore.go
@@ -48,6 +48,13 @@ type MeshNode struct {
 	Description string
 	Services    map[string]string
 	Type        string
+	Tombstone   bool
+}
+
+// Mark: marks the node as unreachable. This is not broadcast on
+// synchronisation
+func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
+	m.store.Mark(nodeId)
 }
 
 // GetHostEndpoint: gets the gRPC endpoint of the node
@@ -200,11 +207,12 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
 
 // Load() loads a mesh network
 func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
 	buf := bytes.NewBuffer(bs)
 	dec := gob.NewDecoder(buf)
 
 	var snapshot TwoPhaseMapSnapshot[string, MeshNode]
 	err := dec.Decode(&snapshot)
+	m.store.Merge(snapshot)
 	return err
 }
 
@@ -256,14 +264,25 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
 	node := m.store.Get(nodeId)
 
+	changes := false
+
 	for _, route := range routes {
-		node.Routes[route.GetDestination().String()] = Route{
-			Destination: route.GetDestination().String(),
-			Path:        route.GetPath(),
+		prevRoute, ok := node.Routes[route.GetDestination().String()]
+
+		if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
+			changes = true
+
+			node.Routes[route.GetDestination().String()] = Route{
+				Destination: route.GetDestination().String(),
+				Path:        route.GetPath(),
+			}
 		}
 	}
 
-	m.store.Put(nodeId, node)
+	if changes {
+		m.store.Put(nodeId, node)
+	}
+
 	return nil
 }
 
@@ -357,8 +376,18 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
 }
 
 // Prune: prunes all nodes that have not updated their timestamp in
-// pruneAmount seconds
+// pruneAmount of seconds
 func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
+	nodes := lib.MapValues(m.store.AsMap())
+	nodes = lib.Filter(nodes, func(mn MeshNode) bool {
+		return time.Now().Unix()-mn.Timestamp > int64(pruneAmount)
+	})
+
+	for _, node := range nodes {
+		key, _ := node.GetPublicKey()
+		m.store.Remove(key.String())
+	}
+
 	return nil
 }
 
@@ -370,6 +399,13 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
 			return false
 		}
 
+		// If the node is marked as unreachable don't consider it a peer.
+		// This helps to optimize convergence time for unreachable nodes.
+		// However advertising it to other nodes could result in flapping.
+		if m.store.IsMarked(mn.PublicKey) {
+			return false
+		}
+
 		return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
 	})
 
diff --git a/pkg/crdt/g_map.go b/pkg/crdt/g_map.go
index 81b916a..505a166 100644
--- a/pkg/crdt/g_map.go
+++ b/pkg/crdt/g_map.go
@@ -6,8 +6,9 @@ import (
 )
 
 type Bucket[D any] struct {
-	Vector   uint64
-	Contents D
+	Vector     uint64
+	Contents   D
+	Gravestone bool
 }
 
 // GMap is a set that can only grow in size
@@ -62,6 +63,31 @@ func (g *GMap[K, D]) Get(key K) D {
 	return g.get(key).Contents
 }
 
+func (g *GMap[K, D]) Mark(key K) {
+	g.lock.Lock()
+	bucket := g.contents[key]
+	bucket.Gravestone = true
+	g.contents[key] = bucket
+	g.lock.Unlock()
+}
+
+// IsMarked: returns true if the node is marked
+func (g *GMap[K, D]) IsMarked(key K) bool {
+	marked := false
+
+	g.lock.RLock()
+
+	bucket, ok := g.contents[key]
+
+	if ok {
+		marked = bucket.Gravestone
+	}
+
+	g.lock.RUnlock()
+
+	return marked
+}
+
 func (g *GMap[K, D]) Keys() []K {
 	g.lock.RLock()
 
diff --git a/pkg/crdt/two_phase_map.go b/pkg/crdt/two_phase_map.go
index 931deec..ee19673 100644
--- a/pkg/crdt/two_phase_map.go
+++ b/pkg/crdt/two_phase_map.go
@@ -60,6 +60,10 @@ func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
 	m.addMap.Put(key, data)
 }
 
+func (m *TwoPhaseMap[K, D]) Mark(key K) {
+	m.addMap.Mark(key)
+}
+
 // Remove removes the value from the map
 func (m *TwoPhaseMap[K, D]) Remove(key K) {
 	m.removeMap.Put(key, true)
@@ -115,6 +119,10 @@ type TwoPhaseMapState[K comparable] struct {
 	RemoveContents map[K]uint64
 }
 
+func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
+	return m.addMap.IsMarked(key)
+}
+
 func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
 	maxClock := uint64(0)
 	m.lock.Lock()
@@ -184,11 +192,15 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
 	m.lock.Lock()
 
 	for key, value := range snapshot.Add {
+		// Gravestone is local only to this node.
+		// Rediscover for ourselves whether the node is alive.
+		value.Gravestone = false
 		m.addMap.put(key, value)
 		m.vectors[key] = max(value.Vector, m.vectors[key])
 	}
 
 	for key, value := range snapshot.Remove {
+		value.Gravestone = false
 		m.removeMap.put(key, value)
 		m.vectors[key] = max(value.Vector, m.vectors[key])
 	}
diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go
index 23e33eb..4b40b45 100644
--- a/pkg/mesh/stub_types.go
+++ b/pkg/mesh/stub_types.go
@@ -81,6 +81,11 @@ type MeshProviderStub struct {
 	snapshot *MeshSnapshotStub
 }
 
+// Mark implements MeshProvider.
+func (*MeshProviderStub) Mark(nodeId string) {
+	panic("unimplemented")
+}
+
 // RemoveNode implements MeshProvider.
 func (*MeshProviderStub) RemoveNode(nodeId string) error {
 	panic("unimplemented")
@@ -117,7 +122,7 @@ func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
 
 // SetAlias implements MeshProvider.
 func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
-	panic("unimplemented")
+	return nil
 }
 
 // RemoveRoutes implements MeshProvider.
diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go
index 251805a..d6f4ccf 100644
--- a/pkg/mesh/types.go
+++ b/pkg/mesh/types.go
@@ -140,6 +140,9 @@ type MeshProvider interface {
 	GetRoutes(targetNode string) (map[string]Route, error)
 	// RemoveNode(): remove the node from the mesh
 	RemoveNode(nodeId string) error
+	// Mark: marks the node as unreachable. This is not broadcast to the
+	// entire mesh and is not considered when syncing node state
+	Mark(nodeId string)
 }
 
 // HostParameters contains the IDs of a node
diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go
index 0090d63..d450837 100644
--- a/pkg/sync/syncer.go
+++ b/pkg/sync/syncer.go
@@ -1,8 +1,8 @@
 package sync
 
 import (
+	"io"
 	"math/rand"
-	"sync"
 	"time"
 
 	"github.com/tim-beatham/wgmesh/pkg/conf"
@@ -59,36 +59,43 @@ func (s *SyncerImpl) Sync(meshId string) error {
 	if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
 		logging.Log.WriteInfof("Sending to random cluster")
-		interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
-		randomSubset = append(randomSubset, interCluster)
+		randomSubset[len(randomSubset)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
 	}
 
-	var waitGroup sync.WaitGroup
+	var succeeded bool = false
 
-	for index := range randomSubset {
-		waitGroup.Add(1)
+	// Do this synchronously to conserve bandwidth
+	for _, node := range randomSubset {
+		correspondingPeer := s.manager.GetNode(meshId, node)
 
-		go func(i int) error {
-			defer waitGroup.Done()
+		if correspondingPeer == nil {
+			logging.Log.WriteErrorf("node %s does not exist", node)
+			continue
+		}
 
-			correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
+		err = s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
 
-			if correspondingPeer == nil {
-				logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
-			}
-
-			err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
-			return err
-		}(index)
+		if err == nil || err == io.EOF {
+			succeeded = true
+		} else {
+			// If the synchronisation operation has failed then mark a gravestone
+			// preventing the peer from being re-contacted until it has updated
+			// itself
+			s.manager.GetMesh(meshId).Mark(node)
+		}
 	}
 
-	waitGroup.Wait()
-
 	s.syncCount++
 	logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
 	logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
 
 	s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
+
+	if !succeeded {
+		// If we could not gossip with anyone then retry on the next round.
+		s.infectionCount++
+	}
+
 	s.manager.GetMesh(meshId).SaveChanges()
 	return nil
 }