45-use-statistical-testing

Using statistical testing to test whether the node has failed.
This commit is contained in:
Tim Beatham
2023-12-07 01:44:54 +00:00
parent 2169f7796f
commit 64885f1055
14 changed files with 231 additions and 130 deletions

View File

@@ -13,7 +13,6 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/wgmesh/pkg/sync"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@@ -63,8 +62,6 @@ func main() {
syncRequester = sync.NewSyncRequester(ctrlServer) syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@@ -82,13 +79,10 @@ func main() {
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop() syncScheduler.Stop()
timestampScheduler.Stop()
ctrlServer.Close() ctrlServer.Close()
client.Close() client.Close()
} }

View File

@@ -477,54 +477,54 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m) return NewAutomergeSync(m)
} }
func (m *CrdtMeshManager) Prune(pruneTime int) error { func (m *CrdtMeshManager) Prune() error {
nodes, err := m.doc.Path("nodes").Get() // nodes, err := m.doc.Path("nodes").Get()
if err != nil { // if err != nil {
return err // return err
} // }
if nodes.Kind() != automerge.KindMap { // if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map") // return errors.New("node must be a map")
} // }
values, err := nodes.Map().Values() // values, err := nodes.Map().Values()
if err != nil { // if err != nil {
return err // return err
} // }
deletionNodes := make([]string, 0) // deletionNodes := make([]string, 0)
for nodeId, node := range values { // for nodeId, node := range values {
if node.Kind() != automerge.KindMap { // if node.Kind() != automerge.KindMap {
return errors.New("node must be a map") // return errors.New("node must be a map")
} // }
nodeMap := node.Map() // nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp") // timeStamp, err := nodeMap.Get("timestamp")
if err != nil { // if err != nil {
return err // return err
} // }
if timeStamp.Kind() != automerge.KindInt64 { // if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64") // return errors.New("timestamp is not int64")
} // }
timeValue := timeStamp.Int64() // timeValue := timeStamp.Int64()
nowValue := time.Now().Unix() // nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) { // if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId) // deletionNodes = append(deletionNodes, nodeId)
} // }
} // }
for _, node := range deletionNodes { // for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node) // logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node) // nodes.Map().Delete(node)
} // }
return nil return nil
} }

View File

@@ -375,18 +375,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
} }
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount of seconds func (m *TwoPhaseStoreMeshManager) Prune() error {
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error { m.store.Prune()
nodes := lib.MapValues(m.store.AsMap())
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
return time.Now().Unix()-mn.Timestamp > int64(pruneAmount)
})
for _, node := range nodes {
key, _ := node.GetPublicKey()
m.store.Remove(key.String())
}
return nil return nil
} }
@@ -405,7 +395,7 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
return false return false
} }
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime) return true
}) })
return lib.Map(nodes, func(mn MeshNode) string { return lib.Map(nodes, func(mn MeshNode) string {

View File

@@ -3,6 +3,8 @@ package crdt
import ( import (
"sync" "sync"
"github.com/tim-beatham/wgmesh/pkg/lib"
) )
type Bucket[D any] struct { type Bucket[D any] struct {
@@ -15,13 +17,13 @@ type Bucket[D any] struct {
type GMap[K comparable, D any] struct { type GMap[K comparable, D any] struct {
lock sync.RWMutex lock sync.RWMutex
contents map[K]Bucket[D] contents map[K]Bucket[D]
getClock func() uint64 clock *VectorClock[K]
} }
func (g *GMap[K, D]) Put(key K, value D) { func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock() g.lock.Lock()
clock := g.getClock() + 1 clock := g.clock.IncrementClock()
g.contents[key] = Bucket[D]{ g.contents[key] = Bucket[D]{
Vector: clock, Vector: clock,
@@ -67,6 +69,7 @@ func (g *GMap[K, D]) Mark(key K) {
g.lock.Lock() g.lock.Lock()
bucket := g.contents[key] bucket := g.contents[key]
bucket.Gravestone = true bucket.Gravestone = true
g.contents[key] = bucket
g.lock.Unlock() g.lock.Unlock()
} }
@@ -138,9 +141,34 @@ func (g *GMap[K, D]) GetClock() map[K]uint64 {
return clock return clock
} }
func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] { func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
func (g *GMap[K, D]) Prune() {
outliers := lib.GetOutliers(g.clock.GetClock(), 0.05)
g.lock.Lock()
for _, outlier := range outliers {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K comparable, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{ return &GMap[K, D]{
contents: make(map[K]Bucket[D]), contents: make(map[K]Bucket[D]),
getClock: getClock, clock: clock,
} }
} }

View File

@@ -1,17 +1,14 @@
package crdt package crdt
import ( import (
"sync"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
) )
type TwoPhaseMap[K comparable, D any] struct { type TwoPhaseMap[K comparable, D any] struct {
addMap *GMap[K, D] addMap *GMap[K, D]
removeMap *GMap[K, bool] removeMap *GMap[K, bool]
vectors map[K]uint64 Clock *VectorClock[K]
processId K processId K
lock sync.RWMutex
} }
type TwoPhaseMapSnapshot[K comparable, D any] struct { type TwoPhaseMapSnapshot[K comparable, D any] struct {
@@ -48,15 +45,8 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
// Put places the key K in the map // Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) { func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.incrementClock() msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.lock.Lock()
if _, ok := m.vectors[key]; !ok {
m.vectors[key] = msgSequence
}
m.lock.Unlock()
m.addMap.Put(key, data) m.addMap.Put(key, data)
} }
@@ -115,6 +105,7 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
} }
type TwoPhaseMapState[K comparable] struct { type TwoPhaseMapState[K comparable] struct {
Vectors map[K]uint64
AddContents map[K]uint64 AddContents map[K]uint64
RemoveContents map[K]uint64 RemoveContents map[K]uint64
} }
@@ -123,32 +114,11 @@ func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key) return m.addMap.IsMarked(key)
} }
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value)
}
m.vectors[m.processId] = maxClock + 1
m.lock.Unlock()
return maxClock
}
// GetHash: Get the hash of the current state of the map // GetHash: Get the hash of the current state of the map
// Sums the current values of the vectors. Provides good approximation // Sums the current values of the vectors. Provides good approximation
// of increasing numbers // of increasing numbers
func (m *TwoPhaseMap[K, D]) GetHash() uint64 { func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
m.lock.RLock() return m.addMap.GetHash() + m.removeMap.GetHash()
sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 {
return current + sum
})
m.lock.RUnlock()
return sum
} }
// GetState: get the current vector clock of the add and remove // GetState: get the current vector clock of the add and remove
@@ -158,11 +128,18 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
removeContents := m.removeMap.GetClock() removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{ return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents, AddContents: addContents,
RemoveContents: removeContents, RemoveContents: removeContents,
} }
} }
func (m *TwoPhaseMap[K, D]) UpdateVector(state *TwoPhaseMapState[K]) {
for key, value := range state.Vectors {
m.Clock.Put(key, value)
}
}
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] { func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{ mapState := &TwoPhaseMapState[K]{
AddContents: make(map[K]uint64), AddContents: make(map[K]uint64),
@@ -189,23 +166,23 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
} }
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) { func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
m.lock.Lock()
for key, value := range snapshot.Add { for key, value := range snapshot.Add {
// Gravestone is local only to that node. // Gravestone is local only to that node.
// Discover ourselves if the node is alive // Discover ourselves if the node is alive
value.Gravestone = false
m.addMap.put(key, value) m.addMap.put(key, value)
m.vectors[key] = max(value.Vector, m.vectors[key]) m.Clock.Put(key, value.Vector)
} }
for key, value := range snapshot.Remove { for key, value := range snapshot.Remove {
value.Gravestone = false
m.removeMap.put(key, value) m.removeMap.put(key, value)
m.vectors[key] = max(value.Vector, m.vectors[key]) m.Clock.Put(key, value.Vector)
}
} }
m.lock.Unlock() func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
} }
// NewTwoPhaseMap: create a new two phase map. Consists of two maps // NewTwoPhaseMap: create a new two phase map. Consists of two maps
@@ -213,11 +190,11 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
// it in the map // it in the map
func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] { func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{ m := TwoPhaseMap[K, D]{
vectors: make(map[K]uint64),
processId: processId, processId: processId,
Clock: NewVectorClock(processId),
} }
m.addMap = NewGMap[K, D](m.incrementClock) m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.incrementClock) m.removeMap = NewGMap[K, bool](m.Clock)
return &m return &m
} }

View File

@@ -100,7 +100,6 @@ func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
dec.Decode(&snapshot) dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot) syncer.manager.store.Merge(snapshot)
return nil, false return nil, false
} }
@@ -125,6 +124,7 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
func (t *TwoPhaseSyncer) Complete() { func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED") logging.Log.WriteInfof("SYNC COMPLETED")
t.manager.store.Clock.IncrementClock()
} }
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer { func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {

80
pkg/crdt/vector_clock.go Normal file
View File

@@ -0,0 +1,80 @@
package crdt
import (
"sync"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// Vector clock defines an abstract data type
// for a vector clock implementation
type VectorClock[K comparable] struct {
vectors map[K]uint64
lock sync.RWMutex
processID K
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value)
}
m.vectors[m.processID] = maxClock + 1
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 {
return current + sum
})
m.lock.RUnlock()
return sum
}
func (m *VectorClock[K]) Prune() {
outliers := lib.GetOutliers(m.vectors, 0.05)
m.lock.Lock()
for _, outlier := range outliers {
delete(m.vectors, outlier)
}
m.lock.Unlock()
}
func (m *VectorClock[K]) Put(key K, value uint64) {
m.lock.Lock()
m.vectors[key] = max(value, m.vectors[key])
m.lock.Unlock()
}
func (m *VectorClock[K]) GetClock() map[K]uint64 {
clock := make(map[K]uint64)
m.lock.RLock()
for key, value := range clock {
clock[key] = value
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K comparable](processID K) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[K]uint64),
processID: processID,
}
}

39
pkg/lib/stats.go Normal file
View File

@@ -0,0 +1,39 @@
// lib contains helper functions for the implementation
package lib
import (
"math"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
// Modelling the distribution using a normal distribution get the count
// of the outliers
func GetOutliers[K comparable](counts map[K]uint64, alpha float64) []K {
n := float64(len(counts))
keys := MapKeys(counts)
values := make([]float64, len(keys))
for index, key := range keys {
values[index] = float64(counts[key])
}
mean := stat.Mean(values, nil)
stdDev := stat.StdDev(values, nil)
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
lowerBound := mean - moe
var outliers []K
for i, count := range values {
if count < lowerBound {
outliers = append(outliers, keys[i])
}
}
return outliers
}

View File

@@ -112,7 +112,7 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
// Prune implements MeshManager. // Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error { func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes { for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime) err := mesh.Prune()
if err != nil { if err != nil {
return err return err

View File

@@ -131,7 +131,7 @@ func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
} }
// Prune implements MeshProvider. // Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error { func (*MeshProviderStub) Prune() error {
return nil return nil
} }

View File

@@ -131,9 +131,9 @@ type MeshProvider interface {
AddService(nodeId, key, value string) error AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist // RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their
// pruneAmount seconds // vector clock
Prune(pruneAmount int) error Prune() error
// GetPeers: get a list of contactable peers // GetPeers: get a list of contactable peers
GetPeers() []string GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen // GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen

View File

@@ -72,7 +72,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteErrorf("node %s does not exist", node) logging.Log.WriteErrorf("node %s does not exist", node)
} }
err = s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint()) err = s.requester.SyncMesh(meshId, correspondingPeer)
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
succeeded = true succeeded = true
@@ -96,6 +96,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
} }
s.manager.GetMesh(meshId).SaveChanges() s.manager.GetMesh(meshId).SaveChanges()
s.manager.Prune()
return nil return nil
} }

View File

@@ -17,31 +17,20 @@ type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager meshManager mesh.MeshManager
} }
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool { func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId) mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId)
if mesh == nil {
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true return true
} }
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool { func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
errStatus, _ := status.FromError(err) errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint) return s.handleFailed(meshId, nodeId)
} }
return false return false

View File

@@ -15,7 +15,7 @@ import (
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
type SyncRequester interface { type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error SyncMesh(meshid string, meshNode mesh.MeshNode) error
} }
type SyncRequesterImpl struct { type SyncRequesterImpl struct {
@@ -56,8 +56,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
return err return err
} }
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error { func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err) ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok { if ok {
return nil return nil
@@ -67,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
} }
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil { if err != nil {
@@ -96,7 +99,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = s.syncMesh(mesh, ctx, c) err = s.syncMesh(mesh, ctx, c)
if err != nil { if err != nil {
return s.handleErr(meshId, endpoint, err) return s.handleErr(meshId, pubKey.String(), err)
} }
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)