From 661fb0d54ce005b2aec558f36f7610f754fdb36c Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Thu, 7 Dec 2023 18:18:13 +0000 Subject: [PATCH] 45-use-statistical-testing Keepalive is based on per mesh and not per node. Using total ordering mechanism similar to paxos to elect a leader if leader doesn't update its timestamp within 3 * keepAlive then give the leader a gravestone and elect the next leader. Leader is based on lexicographically ordered public key. --- cmd/wgmeshd/main.go | 4 ++ pkg/crdt/datastore.go | 26 +++++++++ pkg/crdt/factory.go | 7 ++- pkg/crdt/g_map.go | 12 ++-- pkg/crdt/two_phase_map.go | 14 +++-- pkg/crdt/two_phase_map_syncer.go | 52 +++++++++++++++-- pkg/crdt/vector_clock.go | 95 +++++++++++++++++++++++++++----- pkg/lib/conv.go | 8 ++- pkg/lib/stats.go | 3 +- pkg/mesh/manager.go | 2 - pkg/mesh/types.go | 2 +- pkg/sync/syncer.go | 59 +++++++++++++------- pkg/wg/wg.go | 4 +- 13 files changed, 224 insertions(+), 64 deletions(-) diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index 5cdf670..d495830 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -13,6 +13,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/sync" + timer "github.com/tim-beatham/wgmesh/pkg/timers" "golang.zx2c4.com/wireguard/wgctrl" ) @@ -62,6 +63,7 @@ func main() { syncRequester = sync.NewSyncRequester(ctrlServer) syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer) + keepAlive := timer.NewTimestampScheduler(ctrlServer) robinIpcParams := robin.RobinIpcParams{ CtrlServer: ctrlServer, @@ -79,10 +81,12 @@ func main() { go ipc.RunIpcHandler(&robinIpc) go syncScheduler.Run() + go keepAlive.Run() closeResources := func() { logging.Log.WriteInfof("Closing resources") syncScheduler.Stop() + keepAlive.Stop() ctrlServer.Close() client.Close() } diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go 
index b6b4b93..638af18 100644 --- a/pkg/crdt/datastore.go +++ b/pkg/crdt/datastore.go @@ -5,6 +5,7 @@ import ( "encoding/gob" "fmt" "net" + "slices" "strings" "time" @@ -245,6 +246,31 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) } + // Sort nodes by their public key + peers := m.GetPeers() + slices.Sort(peers) + + if len(peers) == 0 { + return nil + } + + peerToUpdate := peers[0] + + if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) { + m.store.Mark(peerToUpdate) + + if len(peers) < 2 { + return nil + } + + peerToUpdate = peers[1] + } + + if peerToUpdate != nodeId { + return nil + } + + // Refresh causing node to update it's time stamp node := m.store.Get(nodeId) node.Timestamp = time.Now().Unix() m.store.Put(nodeId, node) diff --git a/pkg/crdt/factory.go b/pkg/crdt/factory.go index 5e2ddc6..571d430 100644 --- a/pkg/crdt/factory.go +++ b/pkg/crdt/factory.go @@ -2,6 +2,7 @@ package crdt import ( "fmt" + "hash/fnv" "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/lib" @@ -16,7 +17,11 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) IfName: params.DevName, Client: params.Client, conf: params.Conf, - store: NewTwoPhaseMap[string, MeshNode](params.NodeID), + store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { + h := fnv.New32a() + h.Write([]byte(s)) + return uint64(h.Sum32()) + }, uint64(3*params.Conf.KeepAliveTime)), }, nil } diff --git a/pkg/crdt/g_map.go b/pkg/crdt/g_map.go index db5dd55..1d5bab4 100644 --- a/pkg/crdt/g_map.go +++ b/pkg/crdt/g_map.go @@ -2,9 +2,8 @@ package crdt import ( + "cmp" "sync" - - "github.com/tim-beatham/wgmesh/pkg/lib" ) type Bucket[D any] struct { @@ -14,7 +13,7 @@ type Bucket[D any] struct { } // GMap is a set that can only grow in size -type GMap[K comparable, D any] struct { +type GMap[K 
cmp.Ordered, D any] struct { lock sync.RWMutex contents map[K]Bucket[D] clock *VectorClock[K] @@ -155,18 +154,17 @@ func (g *GMap[K, D]) GetHash() uint64 { } func (g *GMap[K, D]) Prune() { - outliers := lib.GetOutliers(g.clock.GetClock(), 0.05) - + stale := g.clock.getStale() g.lock.Lock() - for _, outlier := range outliers { + for _, outlier := range stale { delete(g.contents, outlier) } g.lock.Unlock() } -func NewGMap[K comparable, D any](clock *VectorClock[K]) *GMap[K, D] { +func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] { return &GMap[K, D]{ contents: make(map[K]Bucket[D]), clock: clock, diff --git a/pkg/crdt/two_phase_map.go b/pkg/crdt/two_phase_map.go index 8071e32..452ba81 100644 --- a/pkg/crdt/two_phase_map.go +++ b/pkg/crdt/two_phase_map.go @@ -1,17 +1,19 @@ package crdt import ( + "cmp" + "github.com/tim-beatham/wgmesh/pkg/lib" ) -type TwoPhaseMap[K comparable, D any] struct { +type TwoPhaseMap[K cmp.Ordered, D any] struct { addMap *GMap[K, D] removeMap *GMap[K, bool] Clock *VectorClock[K] processId K } -type TwoPhaseMapSnapshot[K comparable, D any] struct { +type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct { Add map[K]Bucket[D] Remove map[K]Bucket[bool] } @@ -104,7 +106,7 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh } } -type TwoPhaseMapState[K comparable] struct { +type TwoPhaseMapState[K cmp.Ordered] struct { Vectors map[K]uint64 AddContents map[K]uint64 RemoveContents map[K]uint64 @@ -154,7 +156,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa } } - for key, value := range state.AddContents { + for key, value := range state.RemoveContents { otherValue, ok := m.RemoveContents[key] if !ok || otherValue < value { @@ -188,10 +190,10 @@ func (m *TwoPhaseMap[K, D]) Prune() { // NewTwoPhaseMap: create a new two phase map. Consists of two maps // a grow map and a remove map. 
If both timestamps equal then favour keeping // it in the map -func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] { +func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] { m := TwoPhaseMap[K, D]{ processId: processId, - Clock: NewVectorClock(processId), + Clock: NewVectorClock(processId, hashKey, staleTime), } m.addMap = NewGMap[K, D](m.Clock) diff --git a/pkg/crdt/two_phase_map_syncer.go b/pkg/crdt/two_phase_map_syncer.go index 4fe9fc1..41e372f 100644 --- a/pkg/crdt/two_phase_map_syncer.go +++ b/pkg/crdt/two_phase_map_syncer.go @@ -10,7 +10,8 @@ import ( type SyncState int const ( - PREPARE SyncState = iota + HASH SyncState = iota + PREPARE PRESENT EXCHANGE MERGE @@ -26,13 +27,51 @@ type TwoPhaseSyncer struct { peerMsg []byte } +type TwoPhaseHash struct { + Hash uint64 +} + type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool) -func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) { +func hash(syncer *TwoPhaseSyncer) ([]byte, bool) { + hash := TwoPhaseHash{ + Hash: syncer.manager.store.Clock.GetHash(), + } + var buffer bytes.Buffer enc := gob.NewEncoder(&buffer) - err := enc.Encode(*syncer.mapState) + err := enc.Encode(hash) + + if err != nil { + logging.Log.WriteInfof(err.Error()) + } + + syncer.IncrementState() + return buffer.Bytes(), true +} + +func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) { + var recvBuffer = bytes.NewBuffer(syncer.peerMsg) + dec := gob.NewDecoder(recvBuffer) + + var hash TwoPhaseHash + err := dec.Decode(&hash) + + if err != nil { + logging.Log.WriteInfof(err.Error()) + } + + // If vector clocks are equal then no need to merge state + // Helps to reduce bandwidth by detecting early + if hash.Hash == syncer.manager.store.Clock.GetHash() { + return nil, false + } + + var buffer bytes.Buffer + enc := gob.NewEncoder(&buffer) + + err = enc.Encode(*syncer.mapState) if err != nil { logging.Log.WriteInfof(err.Error()) @@ -124,11 +163,14 @@ func (t 
*TwoPhaseSyncer) RecvMessage(msg []byte) error { func (t *TwoPhaseSyncer) Complete() { logging.Log.WriteInfof("SYNC COMPLETED") - t.manager.store.Clock.IncrementClock() + if t.state == FINISHED || t.state == MERGE { + t.manager.store.Clock.IncrementClock() + } } func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer { var generateMessageFsm SyncFSM = SyncFSM{ + HASH: hash, PREPARE: prepare, PRESENT: present, EXCHANGE: exchange, @@ -137,7 +179,7 @@ func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer { return &TwoPhaseSyncer{ manager: manager, - state: PREPARE, + state: HASH, mapState: manager.store.GenerateMessage(), generateMessageFSM: generateMessageFsm, } diff --git a/pkg/crdt/vector_clock.go b/pkg/crdt/vector_clock.go index 35dd171..0439efa 100644 --- a/pkg/crdt/vector_clock.go +++ b/pkg/crdt/vector_clock.go @@ -1,17 +1,29 @@ package crdt import ( + "cmp" + "slices" "sync" + "time" "github.com/tim-beatham/wgmesh/pkg/lib" ) +type VectorBucket struct { + // clock current value of the node's clock + clock uint64 + // lastUpdate we've seen + lastUpdate uint64 +} + // Vector clock defines an abstract data type // for a vector clock implementation -type VectorClock[K comparable] struct { - vectors map[K]uint64 +type VectorClock[K cmp.Ordered] struct { + vectors map[K]*VectorBucket lock sync.RWMutex processID K + staleTime uint64 + hashFunc func(K) uint64 } // IncrementClock: increments the node's value in the vector clock @@ -20,10 +32,16 @@ func (m *VectorClock[K]) IncrementClock() uint64 { m.lock.Lock() for _, value := range m.vectors { - maxClock = max(maxClock, value) + maxClock = max(maxClock, value.clock) } - m.vectors[m.processID] = maxClock + 1 + newBucket := VectorBucket{ + clock: maxClock + 1, + lastUpdate: uint64(time.Now().Unix()), + } + + m.vectors[m.processID] = &newBucket + m.lock.Unlock() return maxClock } @@ -33,29 +51,73 @@ func (m *VectorClock[K]) IncrementClock() uint64 { func (m *VectorClock[K]) GetHash() 
uint64 { m.lock.RLock() - sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 { - return current + sum - }) + hash := uint64(0) + + sortedKeys := lib.MapKeys(m.vectors) + slices.Sort(sortedKeys) + + for key, bucket := range m.vectors { + hash += m.hashFunc(key) + hash += bucket.clock + } m.lock.RUnlock() - return sum + return hash +} + +// getStale: get all entries that are stale within the mesh +func (m *VectorClock[K]) getStale() []K { + m.lock.RLock() + maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 { + return max(i, vb.lastUpdate) + }) + + toRemove := make([]K, 0) + + for key, bucket := range m.vectors { + if maxTimeStamp-bucket.lastUpdate > m.staleTime { + toRemove = append(toRemove, key) + } + } + + m.lock.RUnlock() + return toRemove } func (m *VectorClock[K]) Prune() { - outliers := lib.GetOutliers(m.vectors, 0.05) + stale := m.getStale() m.lock.Lock() - for _, outlier := range outliers { - delete(m.vectors, outlier) + for _, key := range stale { + delete(m.vectors, key) } m.lock.Unlock() } +func (m *VectorClock[K]) GetTimestamp(processId K) uint64 { + return m.vectors[processId].lastUpdate +} + func (m *VectorClock[K]) Put(key K, value uint64) { + clockValue := uint64(0) + m.lock.Lock() - m.vectors[key] = max(value, m.vectors[key]) + bucket, ok := m.vectors[key] + + if ok { + clockValue = bucket.clock + } + + if value > clockValue { + newBucket := VectorBucket{ + clock: value, + lastUpdate: uint64(time.Now().Unix()), + } + m.vectors[key] = &newBucket + } + m.lock.Unlock() } @@ -64,6 +126,9 @@ func (m *VectorClock[K]) GetClock() map[K]uint64 { m.lock.RLock() + keys := lib.MapKeys(m.vectors) + slices.Sort(keys) + for key, value := range clock { clock[key] = value } @@ -72,9 +137,11 @@ func (m *VectorClock[K]) GetClock() map[K]uint64 { return clock } -func NewVectorClock[K comparable](processID K) *VectorClock[K] { +func NewVectorClock[K cmp.Ordered](processID K, 
hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] { return &VectorClock[K]{ - vectors: make(map[K]uint64), + vectors: make(map[K]*VectorBucket), processID: processID, + staleTime: staleTime, + hashFunc: hashFunc, } } diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go index b540088..053ef33 100644 --- a/pkg/lib/conv.go +++ b/pkg/lib/conv.go @@ -1,11 +1,13 @@ package lib +import "cmp" + // MapToSlice converts a map to a slice in go -func MapValues[K comparable, V any](m map[K]V) []V { +func MapValues[K cmp.Ordered, V any](m map[K]V) []V { return MapValuesWithExclude(m, map[K]struct{}{}) } -func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V { +func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V { values := make([]V, len(m)-len(exclude)) i := 0 @@ -26,7 +28,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{} return values } -func MapKeys[K comparable, V any](m map[K]V) []K { +func MapKeys[K cmp.Ordered, V any](m map[K]V) []K { values := make([]K, len(m)) i := 0 diff --git a/pkg/lib/stats.go b/pkg/lib/stats.go index 3aff713..7b04b76 100644 --- a/pkg/lib/stats.go +++ b/pkg/lib/stats.go @@ -2,6 +2,7 @@ package lib import ( + "cmp" "math" "gonum.org/v1/gonum/stat" @@ -10,7 +11,7 @@ import ( // Modelling the distribution using a normal distribution get the count // of the outliers -func GetOutliers[K comparable](counts map[K]uint64, alpha float64) []K { +func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K { n := float64(len(counts)) keys := MapKeys(counts) diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 981404e..7576e87 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -8,7 +8,6 @@ import ( "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" - logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/wg" 
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -329,7 +328,6 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { return nil, fmt.Errorf("mesh %s does not exist", meshId) } - logging.Log.WriteInfof(s.HostParameters.GetPublicKey()) node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey()) if err != nil { diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 72556d5..d8e9eb2 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -173,7 +173,7 @@ type MeshProviderFactory interface { // MeshNodeFactoryParams are the parameters required to construct // a mesh node type MeshNodeFactoryParams struct { - PublicKey *wgtypes.Key +PublicKey *wgtypes.Key NodeIP net.IP WgPort int Endpoint string diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index b6aa1b3..075181d 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -12,7 +12,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/mesh" ) -// Syncer: picks random nodes from the mesh +// Syncer: picks random nodes from the meshs type Syncer interface { Sync(meshId string) error SyncMeshes() error @@ -25,54 +25,63 @@ type SyncerImpl struct { syncCount int cluster conn.ConnCluster conf *conf.WgMeshConfiguration + lastSync uint64 } // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { - if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { + self, err := s.manager.GetSelf(meshId) + + if err != nil { + return err + } + + s.manager.GetMesh(meshId).Prune() + + if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { logging.Log.WriteInfof("No changes for %s", meshId) return nil } - logging.Log.WriteInfof("UPDATING WG CONF") - + before := time.Now() s.manager.GetRouteManager().UpdateRoutes() - err := s.manager.ApplyConfig() - - if err != nil { - logging.Log.WriteInfof("Failed to update config %w", err) - } publicKey := s.manager.GetPublicKey() logging.Log.WriteInfof(publicKey.String()) nodeNames 
:= s.manager.GetMesh(meshId).GetPeers() - neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) - randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) - for _, node := range randomSubset { - logging.Log.WriteInfof("Random node: %s", node) - } + var gossipNodes []string - before := time.Now() + // Clients always pings its peer for configuration + if self.GetType() == conf.CLIENT_ROLE { + keyFunc := lib.HashString + bucketFunc := lib.HashString - if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { - logging.Log.WriteInfof("Sending to random cluster") - randomSubset[len(randomSubset)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String()) + neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc) + gossipNodes = make([]string, 1) + gossipNodes[0] = neighbour + } else { + neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) + gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) + + if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { + gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String()) + } } var succeeded bool = false // Do this synchronously to conserve bandwidth - for _, node := range randomSubset { + for _, node := range gossipNodes { correspondingPeer := s.manager.GetNode(meshId, node) if correspondingPeer == nil { logging.Log.WriteErrorf("node %s does not exist", node) } - err = s.requester.SyncMesh(meshId, correspondingPeer) + err := s.requester.SyncMesh(meshId, correspondingPeer) if err == nil || err == io.EOF { succeeded = true @@ -96,7 +105,15 @@ func (s *SyncerImpl) Sync(meshId string) error { } s.manager.GetMesh(meshId).SaveChanges() - s.manager.Prune() + s.lastSync = uint64(time.Now().Unix()) + + logging.Log.WriteInfof("UPDATING WG CONF") + err = s.manager.ApplyConfig() + + if err != nil { + logging.Log.WriteInfof("Failed to update config %w", err) + } 
+ return nil } diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index 9a4396a..03022e0 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -3,7 +3,6 @@ package wg import ( "crypto" "crypto/rand" - "encoding/base64" "fmt" "github.com/tim-beatham/wgmesh/pkg/lib" @@ -35,8 +34,7 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes. } md5 := crypto.MD5.New().Sum(randomBuf) - - md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength]) + md5Str := fmt.Sprintf("wg%x", md5)[:hashLength] err = rtnl.CreateLink(md5Str)