forked from extern/smegmesh
Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
This commit is contained in:
commit
92c0805275
@ -63,8 +63,7 @@ func main() {
|
|||||||
syncRequester = sync.NewSyncRequester(ctrlServer)
|
syncRequester = sync.NewSyncRequester(ctrlServer)
|
||||||
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
|
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
|
||||||
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
|
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
|
||||||
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
|
keepAlive := timer.NewTimestampScheduler(ctrlServer)
|
||||||
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
|
|
||||||
|
|
||||||
robinIpcParams := robin.RobinIpcParams{
|
robinIpcParams := robin.RobinIpcParams{
|
||||||
CtrlServer: ctrlServer,
|
CtrlServer: ctrlServer,
|
||||||
@ -82,13 +81,12 @@ func main() {
|
|||||||
|
|
||||||
go ipc.RunIpcHandler(&robinIpc)
|
go ipc.RunIpcHandler(&robinIpc)
|
||||||
go syncScheduler.Run()
|
go syncScheduler.Run()
|
||||||
go timestampScheduler.Run()
|
go keepAlive.Run()
|
||||||
go pruneScheduler.Run()
|
|
||||||
|
|
||||||
closeResources := func() {
|
closeResources := func() {
|
||||||
logging.Log.WriteInfof("Closing resources")
|
logging.Log.WriteInfof("Closing resources")
|
||||||
syncScheduler.Stop()
|
syncScheduler.Stop()
|
||||||
timestampScheduler.Stop()
|
keepAlive.Stop()
|
||||||
ctrlServer.Close()
|
ctrlServer.Close()
|
||||||
client.Close()
|
client.Close()
|
||||||
}
|
}
|
||||||
|
@ -477,54 +477,54 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
|||||||
return NewAutomergeSync(m)
|
return NewAutomergeSync(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
func (m *CrdtMeshManager) Prune() error {
|
||||||
nodes, err := m.doc.Path("nodes").Get()
|
// nodes, err := m.doc.Path("nodes").Get()
|
||||||
|
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
if nodes.Kind() != automerge.KindMap {
|
// if nodes.Kind() != automerge.KindMap {
|
||||||
return errors.New("node must be a map")
|
// return errors.New("node must be a map")
|
||||||
}
|
// }
|
||||||
|
|
||||||
values, err := nodes.Map().Values()
|
// values, err := nodes.Map().Values()
|
||||||
|
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
deletionNodes := make([]string, 0)
|
// deletionNodes := make([]string, 0)
|
||||||
|
|
||||||
for nodeId, node := range values {
|
// for nodeId, node := range values {
|
||||||
if node.Kind() != automerge.KindMap {
|
// if node.Kind() != automerge.KindMap {
|
||||||
return errors.New("node must be a map")
|
// return errors.New("node must be a map")
|
||||||
}
|
// }
|
||||||
|
|
||||||
nodeMap := node.Map()
|
// nodeMap := node.Map()
|
||||||
|
|
||||||
timeStamp, err := nodeMap.Get("timestamp")
|
// timeStamp, err := nodeMap.Get("timestamp")
|
||||||
|
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
if timeStamp.Kind() != automerge.KindInt64 {
|
// if timeStamp.Kind() != automerge.KindInt64 {
|
||||||
return errors.New("timestamp is not int64")
|
// return errors.New("timestamp is not int64")
|
||||||
}
|
// }
|
||||||
|
|
||||||
timeValue := timeStamp.Int64()
|
// timeValue := timeStamp.Int64()
|
||||||
nowValue := time.Now().Unix()
|
// nowValue := time.Now().Unix()
|
||||||
|
|
||||||
if nowValue-timeValue >= int64(pruneTime) {
|
// if nowValue-timeValue >= int64(pruneTime) {
|
||||||
deletionNodes = append(deletionNodes, nodeId)
|
// deletionNodes = append(deletionNodes, nodeId)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
for _, node := range deletionNodes {
|
// for _, node := range deletionNodes {
|
||||||
logging.Log.WriteInfof("Pruning %s", node)
|
// logging.Log.WriteInfof("Pruning %s", node)
|
||||||
nodes.Map().Delete(node)
|
// nodes.Map().Delete(node)
|
||||||
}
|
// }
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -245,6 +246,31 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
|
|||||||
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sort nodes by their public key
|
||||||
|
peers := m.GetPeers()
|
||||||
|
slices.Sort(peers)
|
||||||
|
|
||||||
|
if len(peers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerToUpdate := peers[0]
|
||||||
|
|
||||||
|
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) {
|
||||||
|
m.store.Mark(peerToUpdate)
|
||||||
|
|
||||||
|
if len(peers) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerToUpdate = peers[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerToUpdate != nodeId {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh causing node to update it's time stamp
|
||||||
node := m.store.Get(nodeId)
|
node := m.store.Get(nodeId)
|
||||||
node.Timestamp = time.Now().Unix()
|
node.Timestamp = time.Now().Unix()
|
||||||
m.store.Put(nodeId, node)
|
m.store.Put(nodeId, node)
|
||||||
@ -375,18 +401,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prune: prunes all nodes that have not updated their timestamp in
|
// Prune: prunes all nodes that have not updated their timestamp in
|
||||||
// pruneAmount of seconds
|
func (m *TwoPhaseStoreMeshManager) Prune() error {
|
||||||
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
|
m.store.Prune()
|
||||||
nodes := lib.MapValues(m.store.AsMap())
|
|
||||||
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
|
||||||
return time.Now().Unix()-mn.Timestamp > int64(pruneAmount)
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, node := range nodes {
|
|
||||||
key, _ := node.GetPublicKey()
|
|
||||||
m.store.Remove(key.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,7 +421,7 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
return lib.Map(nodes, func(mn MeshNode) string {
|
return lib.Map(nodes, func(mn MeshNode) string {
|
||||||
|
@ -2,6 +2,7 @@ package crdt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
@ -16,7 +17,11 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
|
|||||||
IfName: params.DevName,
|
IfName: params.DevName,
|
||||||
Client: params.Client,
|
Client: params.Client,
|
||||||
conf: params.Conf,
|
conf: params.Conf,
|
||||||
store: NewTwoPhaseMap[string, MeshNode](params.NodeID),
|
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
|
||||||
|
h := fnv.New32a()
|
||||||
|
h.Write([]byte(s))
|
||||||
|
return uint64(h.Sum32())
|
||||||
|
}, uint64(3*params.Conf.KeepAliveTime)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
package crdt
|
package crdt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,16 +13,16 @@ type Bucket[D any] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GMap is a set that can only grow in size
|
// GMap is a set that can only grow in size
|
||||||
type GMap[K comparable, D any] struct {
|
type GMap[K cmp.Ordered, D any] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
contents map[K]Bucket[D]
|
contents map[K]Bucket[D]
|
||||||
getClock func() uint64
|
clock *VectorClock[K]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Put(key K, value D) {
|
func (g *GMap[K, D]) Put(key K, value D) {
|
||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
|
|
||||||
clock := g.getClock() + 1
|
clock := g.clock.IncrementClock()
|
||||||
|
|
||||||
g.contents[key] = Bucket[D]{
|
g.contents[key] = Bucket[D]{
|
||||||
Vector: clock,
|
Vector: clock,
|
||||||
@ -67,6 +68,7 @@ func (g *GMap[K, D]) Mark(key K) {
|
|||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
bucket := g.contents[key]
|
bucket := g.contents[key]
|
||||||
bucket.Gravestone = true
|
bucket.Gravestone = true
|
||||||
|
g.contents[key] = bucket
|
||||||
g.lock.Unlock()
|
g.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,9 +140,33 @@ func (g *GMap[K, D]) GetClock() map[K]uint64 {
|
|||||||
return clock
|
return clock
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] {
|
func (g *GMap[K, D]) GetHash() uint64 {
|
||||||
|
hash := uint64(0)
|
||||||
|
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
for _, value := range g.contents {
|
||||||
|
hash += value.Vector
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
return hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Prune() {
|
||||||
|
stale := g.clock.getStale()
|
||||||
|
g.lock.Lock()
|
||||||
|
|
||||||
|
for _, outlier := range stale {
|
||||||
|
delete(g.contents, outlier)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
|
||||||
return &GMap[K, D]{
|
return &GMap[K, D]{
|
||||||
contents: make(map[K]Bucket[D]),
|
contents: make(map[K]Bucket[D]),
|
||||||
getClock: getClock,
|
clock: clock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,20 +1,19 @@
|
|||||||
package crdt
|
package crdt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"cmp"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TwoPhaseMap[K comparable, D any] struct {
|
type TwoPhaseMap[K cmp.Ordered, D any] struct {
|
||||||
addMap *GMap[K, D]
|
addMap *GMap[K, D]
|
||||||
removeMap *GMap[K, bool]
|
removeMap *GMap[K, bool]
|
||||||
vectors map[K]uint64
|
Clock *VectorClock[K]
|
||||||
processId K
|
processId K
|
||||||
lock sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseMapSnapshot[K comparable, D any] struct {
|
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
|
||||||
Add map[K]Bucket[D]
|
Add map[K]Bucket[D]
|
||||||
Remove map[K]Bucket[bool]
|
Remove map[K]Bucket[bool]
|
||||||
}
|
}
|
||||||
@ -48,15 +47,8 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
|
|||||||
|
|
||||||
// Put places the key K in the map
|
// Put places the key K in the map
|
||||||
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
||||||
msgSequence := m.incrementClock()
|
msgSequence := m.Clock.IncrementClock()
|
||||||
|
m.Clock.Put(key, msgSequence)
|
||||||
m.lock.Lock()
|
|
||||||
|
|
||||||
if _, ok := m.vectors[key]; !ok {
|
|
||||||
m.vectors[key] = msgSequence
|
|
||||||
}
|
|
||||||
|
|
||||||
m.lock.Unlock()
|
|
||||||
m.addMap.Put(key, data)
|
m.addMap.Put(key, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,7 +106,8 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseMapState[K comparable] struct {
|
type TwoPhaseMapState[K cmp.Ordered] struct {
|
||||||
|
Vectors map[K]uint64
|
||||||
AddContents map[K]uint64
|
AddContents map[K]uint64
|
||||||
RemoveContents map[K]uint64
|
RemoveContents map[K]uint64
|
||||||
}
|
}
|
||||||
@ -123,32 +116,11 @@ func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
|||||||
return m.addMap.IsMarked(key)
|
return m.addMap.IsMarked(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
|
|
||||||
maxClock := uint64(0)
|
|
||||||
m.lock.Lock()
|
|
||||||
|
|
||||||
for _, value := range m.vectors {
|
|
||||||
maxClock = max(maxClock, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.vectors[m.processId] = maxClock + 1
|
|
||||||
m.lock.Unlock()
|
|
||||||
return maxClock
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHash: Get the hash of the current state of the map
|
// GetHash: Get the hash of the current state of the map
|
||||||
// Sums the current values of the vectors. Provides good approximation
|
// Sums the current values of the vectors. Provides good approximation
|
||||||
// of increasing numbers
|
// of increasing numbers
|
||||||
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
||||||
m.lock.RLock()
|
return m.addMap.GetHash() + m.removeMap.GetHash()
|
||||||
|
|
||||||
sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 {
|
|
||||||
return current + sum
|
|
||||||
})
|
|
||||||
|
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
return sum
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetState: get the current vector clock of the add and remove
|
// GetState: get the current vector clock of the add and remove
|
||||||
@ -158,11 +130,18 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
|
|||||||
removeContents := m.removeMap.GetClock()
|
removeContents := m.removeMap.GetClock()
|
||||||
|
|
||||||
return &TwoPhaseMapState[K]{
|
return &TwoPhaseMapState[K]{
|
||||||
|
Vectors: m.Clock.GetClock(),
|
||||||
AddContents: addContents,
|
AddContents: addContents,
|
||||||
RemoveContents: removeContents,
|
RemoveContents: removeContents,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) UpdateVector(state *TwoPhaseMapState[K]) {
|
||||||
|
for key, value := range state.Vectors {
|
||||||
|
m.Clock.Put(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
||||||
mapState := &TwoPhaseMapState[K]{
|
mapState := &TwoPhaseMapState[K]{
|
||||||
AddContents: make(map[K]uint64),
|
AddContents: make(map[K]uint64),
|
||||||
@ -177,7 +156,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range state.AddContents {
|
for key, value := range state.RemoveContents {
|
||||||
otherValue, ok := m.RemoveContents[key]
|
otherValue, ok := m.RemoveContents[key]
|
||||||
|
|
||||||
if !ok || otherValue < value {
|
if !ok || otherValue < value {
|
||||||
@ -189,35 +168,35 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
||||||
m.lock.Lock()
|
|
||||||
|
|
||||||
for key, value := range snapshot.Add {
|
for key, value := range snapshot.Add {
|
||||||
// Gravestone is local only to that node.
|
// Gravestone is local only to that node.
|
||||||
// Discover ourselves if the node is alive
|
// Discover ourselves if the node is alive
|
||||||
value.Gravestone = false
|
|
||||||
m.addMap.put(key, value)
|
m.addMap.put(key, value)
|
||||||
m.vectors[key] = max(value.Vector, m.vectors[key])
|
m.Clock.Put(key, value.Vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range snapshot.Remove {
|
for key, value := range snapshot.Remove {
|
||||||
value.Gravestone = false
|
|
||||||
m.removeMap.put(key, value)
|
m.removeMap.put(key, value)
|
||||||
m.vectors[key] = max(value.Vector, m.vectors[key])
|
m.Clock.Put(key, value.Vector)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.lock.Unlock()
|
func (m *TwoPhaseMap[K, D]) Prune() {
|
||||||
|
m.addMap.Prune()
|
||||||
|
m.removeMap.Prune()
|
||||||
|
m.Clock.Prune()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
|
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
|
||||||
// a grow map and a remove map. If both timestamps equal then favour keeping
|
// a grow map and a remove map. If both timestamps equal then favour keeping
|
||||||
// it in the map
|
// it in the map
|
||||||
func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] {
|
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
|
||||||
m := TwoPhaseMap[K, D]{
|
m := TwoPhaseMap[K, D]{
|
||||||
vectors: make(map[K]uint64),
|
|
||||||
processId: processId,
|
processId: processId,
|
||||||
|
Clock: NewVectorClock(processId, hashKey, staleTime),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addMap = NewGMap[K, D](m.incrementClock)
|
m.addMap = NewGMap[K, D](m.Clock)
|
||||||
m.removeMap = NewGMap[K, bool](m.incrementClock)
|
m.removeMap = NewGMap[K, bool](m.Clock)
|
||||||
return &m
|
return &m
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,8 @@ import (
|
|||||||
type SyncState int
|
type SyncState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PREPARE SyncState = iota
|
HASH SyncState = iota
|
||||||
|
PREPARE
|
||||||
PRESENT
|
PRESENT
|
||||||
EXCHANGE
|
EXCHANGE
|
||||||
MERGE
|
MERGE
|
||||||
@ -26,13 +27,51 @@ type TwoPhaseSyncer struct {
|
|||||||
peerMsg []byte
|
peerMsg []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TwoPhaseHash struct {
|
||||||
|
Hash uint64
|
||||||
|
}
|
||||||
|
|
||||||
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
|
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
|
||||||
|
|
||||||
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
hash := TwoPhaseHash{
|
||||||
|
Hash: syncer.manager.store.Clock.GetHash(),
|
||||||
|
}
|
||||||
|
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
enc := gob.NewEncoder(&buffer)
|
enc := gob.NewEncoder(&buffer)
|
||||||
|
|
||||||
err := enc.Encode(*syncer.mapState)
|
err := enc.Encode(hash)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
syncer.IncrementState()
|
||||||
|
return buffer.Bytes(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
|
||||||
|
dec := gob.NewDecoder(recvBuffer)
|
||||||
|
|
||||||
|
var hash TwoPhaseHash
|
||||||
|
err := dec.Decode(&hash)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// If vector clocks are equal then no need to merge state
|
||||||
|
// Helps to reduce bandwidth by detecting early
|
||||||
|
if hash.Hash == syncer.manager.store.Clock.GetHash() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buffer)
|
||||||
|
|
||||||
|
err = enc.Encode(*syncer.mapState)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteInfof(err.Error())
|
||||||
@ -100,7 +139,6 @@ func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
dec.Decode(&snapshot)
|
dec.Decode(&snapshot)
|
||||||
|
|
||||||
syncer.manager.store.Merge(snapshot)
|
syncer.manager.store.Merge(snapshot)
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,10 +163,14 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
|
|||||||
|
|
||||||
func (t *TwoPhaseSyncer) Complete() {
|
func (t *TwoPhaseSyncer) Complete() {
|
||||||
logging.Log.WriteInfof("SYNC COMPLETED")
|
logging.Log.WriteInfof("SYNC COMPLETED")
|
||||||
|
if t.state == FINISHED || t.state == MERGE {
|
||||||
|
t.manager.store.Clock.IncrementClock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
||||||
var generateMessageFsm SyncFSM = SyncFSM{
|
var generateMessageFsm SyncFSM = SyncFSM{
|
||||||
|
HASH: hash,
|
||||||
PREPARE: prepare,
|
PREPARE: prepare,
|
||||||
PRESENT: present,
|
PRESENT: present,
|
||||||
EXCHANGE: exchange,
|
EXCHANGE: exchange,
|
||||||
@ -137,7 +179,7 @@ func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
|||||||
|
|
||||||
return &TwoPhaseSyncer{
|
return &TwoPhaseSyncer{
|
||||||
manager: manager,
|
manager: manager,
|
||||||
state: PREPARE,
|
state: HASH,
|
||||||
mapState: manager.store.GenerateMessage(),
|
mapState: manager.store.GenerateMessage(),
|
||||||
generateMessageFSM: generateMessageFsm,
|
generateMessageFSM: generateMessageFsm,
|
||||||
}
|
}
|
||||||
|
147
pkg/crdt/vector_clock.go
Normal file
147
pkg/crdt/vector_clock.go
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VectorBucket struct {
|
||||||
|
// clock current value of the node's clock
|
||||||
|
clock uint64
|
||||||
|
// lastUpdate we've seen
|
||||||
|
lastUpdate uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vector clock defines an abstract data type
|
||||||
|
// for a vector clock implementation
|
||||||
|
type VectorClock[K cmp.Ordered] struct {
|
||||||
|
vectors map[K]*VectorBucket
|
||||||
|
lock sync.RWMutex
|
||||||
|
processID K
|
||||||
|
staleTime uint64
|
||||||
|
hashFunc func(K) uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementClock: increments the node's value in the vector clock
|
||||||
|
func (m *VectorClock[K]) IncrementClock() uint64 {
|
||||||
|
maxClock := uint64(0)
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
for _, value := range m.vectors {
|
||||||
|
maxClock = max(maxClock, value.clock)
|
||||||
|
}
|
||||||
|
|
||||||
|
newBucket := VectorBucket{
|
||||||
|
clock: maxClock + 1,
|
||||||
|
lastUpdate: uint64(time.Now().Unix()),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.vectors[m.processID] = &newBucket
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
return maxClock
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHash: gets the hash of the vector clock used to determine if there
|
||||||
|
// are any changes
|
||||||
|
func (m *VectorClock[K]) GetHash() uint64 {
|
||||||
|
m.lock.RLock()
|
||||||
|
|
||||||
|
hash := uint64(0)
|
||||||
|
|
||||||
|
sortedKeys := lib.MapKeys(m.vectors)
|
||||||
|
slices.Sort(sortedKeys)
|
||||||
|
|
||||||
|
for key, bucket := range m.vectors {
|
||||||
|
hash += m.hashFunc(key)
|
||||||
|
hash += bucket.clock
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStale: get all entries that are stale within the mesh
|
||||||
|
func (m *VectorClock[K]) getStale() []K {
|
||||||
|
m.lock.RLock()
|
||||||
|
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
|
||||||
|
return max(i, vb.lastUpdate)
|
||||||
|
})
|
||||||
|
|
||||||
|
toRemove := make([]K, 0)
|
||||||
|
|
||||||
|
for key, bucket := range m.vectors {
|
||||||
|
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
|
||||||
|
toRemove = append(toRemove, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return toRemove
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) Prune() {
|
||||||
|
stale := m.getStale()
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
for _, key := range stale {
|
||||||
|
delete(m.vectors, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
||||||
|
return m.vectors[processId].lastUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) Put(key K, value uint64) {
|
||||||
|
clockValue := uint64(0)
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
bucket, ok := m.vectors[key]
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
clockValue = bucket.clock
|
||||||
|
}
|
||||||
|
|
||||||
|
if value > clockValue {
|
||||||
|
newBucket := VectorBucket{
|
||||||
|
clock: value,
|
||||||
|
lastUpdate: uint64(time.Now().Unix()),
|
||||||
|
}
|
||||||
|
m.vectors[key] = &newBucket
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) GetClock() map[K]uint64 {
|
||||||
|
clock := make(map[K]uint64)
|
||||||
|
|
||||||
|
m.lock.RLock()
|
||||||
|
|
||||||
|
keys := lib.MapKeys(m.vectors)
|
||||||
|
slices.Sort(keys)
|
||||||
|
|
||||||
|
for key, value := range clock {
|
||||||
|
clock[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return clock
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
|
||||||
|
return &VectorClock[K]{
|
||||||
|
vectors: make(map[K]*VectorBucket),
|
||||||
|
processID: processID,
|
||||||
|
staleTime: staleTime,
|
||||||
|
hashFunc: hashFunc,
|
||||||
|
}
|
||||||
|
}
|
@ -1,11 +1,13 @@
|
|||||||
package lib
|
package lib
|
||||||
|
|
||||||
|
import "cmp"
|
||||||
|
|
||||||
// MapToSlice converts a map to a slice in go
|
// MapToSlice converts a map to a slice in go
|
||||||
func MapValues[K comparable, V any](m map[K]V) []V {
|
func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
|
||||||
return MapValuesWithExclude(m, map[K]struct{}{})
|
return MapValuesWithExclude(m, map[K]struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V {
|
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
|
||||||
values := make([]V, len(m)-len(exclude))
|
values := make([]V, len(m)-len(exclude))
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
@ -26,7 +28,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
|
|||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func MapKeys[K comparable, V any](m map[K]V) []K {
|
func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
|
||||||
values := make([]K, len(m))
|
values := make([]K, len(m))
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
|
40
pkg/lib/stats.go
Normal file
40
pkg/lib/stats.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// lib contains helper functions for the implementation
|
||||||
|
package lib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/stat"
|
||||||
|
"gonum.org/v1/gonum/stat/distuv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Modelling the distribution using a normal distribution get the count
|
||||||
|
// of the outliers
|
||||||
|
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
|
||||||
|
n := float64(len(counts))
|
||||||
|
|
||||||
|
keys := MapKeys(counts)
|
||||||
|
values := make([]float64, len(keys))
|
||||||
|
|
||||||
|
for index, key := range keys {
|
||||||
|
values[index] = float64(counts[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
mean := stat.Mean(values, nil)
|
||||||
|
stdDev := stat.StdDev(values, nil)
|
||||||
|
|
||||||
|
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
|
||||||
|
|
||||||
|
lowerBound := mean - moe
|
||||||
|
|
||||||
|
var outliers []K
|
||||||
|
|
||||||
|
for i, count := range values {
|
||||||
|
if count < lowerBound {
|
||||||
|
outliers = append(outliers, keys[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return outliers
|
||||||
|
}
|
@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@ -112,7 +111,7 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
|
|||||||
// Prune implements MeshManager.
|
// Prune implements MeshManager.
|
||||||
func (m *MeshManagerImpl) Prune() error {
|
func (m *MeshManagerImpl) Prune() error {
|
||||||
for _, mesh := range m.Meshes {
|
for _, mesh := range m.Meshes {
|
||||||
err := mesh.Prune(m.conf.PruneTime)
|
err := mesh.Prune()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -329,7 +328,6 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
|
|||||||
return nil, fmt.Errorf("mesh %s does not exist", meshId)
|
return nil, fmt.Errorf("mesh %s does not exist", meshId)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
|
|
||||||
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
|
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -131,7 +131,7 @@ func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prune implements MeshProvider.
|
// Prune implements MeshProvider.
|
||||||
func (*MeshProviderStub) Prune(pruneAmount int) error {
|
func (*MeshProviderStub) Prune() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,9 +131,9 @@ type MeshProvider interface {
|
|||||||
AddService(nodeId, key, value string) error
|
AddService(nodeId, key, value string) error
|
||||||
// RemoveService: removes the service form the node. throws an error if the service does not exist
|
// RemoveService: removes the service form the node. throws an error if the service does not exist
|
||||||
RemoveService(nodeId, key string) error
|
RemoveService(nodeId, key string) error
|
||||||
// Prune: prunes all nodes that have not updated their timestamp in
|
// Prune: prunes all nodes that have not updated their
|
||||||
// pruneAmount seconds
|
// vector clock
|
||||||
Prune(pruneAmount int) error
|
Prune() error
|
||||||
// GetPeers: get a list of contactable peers
|
// GetPeers: get a list of contactable peers
|
||||||
GetPeers() []string
|
GetPeers() []string
|
||||||
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
|
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
|
||||||
@ -173,7 +173,7 @@ type MeshProviderFactory interface {
|
|||||||
// MeshNodeFactoryParams are the parameters required to construct
|
// MeshNodeFactoryParams are the parameters required to construct
|
||||||
// a mesh node
|
// a mesh node
|
||||||
type MeshNodeFactoryParams struct {
|
type MeshNodeFactoryParams struct {
|
||||||
PublicKey *wgtypes.Key
|
PublicKey *wgtypes.Key
|
||||||
NodeIP net.IP
|
NodeIP net.IP
|
||||||
WgPort int
|
WgPort int
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Syncer: picks random nodes from the mesh
|
// Syncer: picks random nodes from the meshs
|
||||||
type Syncer interface {
|
type Syncer interface {
|
||||||
Sync(meshId string) error
|
Sync(meshId string) error
|
||||||
SyncMeshes() error
|
SyncMeshes() error
|
||||||
@ -25,54 +25,63 @@ type SyncerImpl struct {
|
|||||||
syncCount int
|
syncCount int
|
||||||
cluster conn.ConnCluster
|
cluster conn.ConnCluster
|
||||||
conf *conf.WgMeshConfiguration
|
conf *conf.WgMeshConfiguration
|
||||||
|
lastSync uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync: Sync random nodes
|
// Sync: Sync random nodes
|
||||||
func (s *SyncerImpl) Sync(meshId string) error {
|
func (s *SyncerImpl) Sync(meshId string) error {
|
||||||
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
self, err := s.manager.GetSelf(meshId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.manager.GetMesh(meshId).Prune()
|
||||||
|
|
||||||
|
if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
||||||
logging.Log.WriteInfof("No changes for %s", meshId)
|
logging.Log.WriteInfof("No changes for %s", meshId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof("UPDATING WG CONF")
|
before := time.Now()
|
||||||
|
|
||||||
s.manager.GetRouteManager().UpdateRoutes()
|
s.manager.GetRouteManager().UpdateRoutes()
|
||||||
err := s.manager.ApplyConfig()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKey := s.manager.GetPublicKey()
|
publicKey := s.manager.GetPublicKey()
|
||||||
|
|
||||||
logging.Log.WriteInfof(publicKey.String())
|
logging.Log.WriteInfof(publicKey.String())
|
||||||
|
|
||||||
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
||||||
|
|
||||||
|
var gossipNodes []string
|
||||||
|
|
||||||
|
// Clients always pings its peer for configuration
|
||||||
|
if self.GetType() == conf.CLIENT_ROLE {
|
||||||
|
keyFunc := lib.HashString
|
||||||
|
bucketFunc := lib.HashString
|
||||||
|
|
||||||
|
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
|
||||||
|
gossipNodes = make([]string, 1)
|
||||||
|
gossipNodes[0] = neighbour
|
||||||
|
} else {
|
||||||
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
||||||
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
||||||
|
|
||||||
for _, node := range randomSubset {
|
|
||||||
logging.Log.WriteInfof("Random node: %s", node)
|
|
||||||
}
|
|
||||||
|
|
||||||
before := time.Now()
|
|
||||||
|
|
||||||
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
|
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
|
||||||
logging.Log.WriteInfof("Sending to random cluster")
|
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
||||||
randomSubset[len(randomSubset)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var succeeded bool = false
|
var succeeded bool = false
|
||||||
|
|
||||||
// Do this synchronously to conserve bandwidth
|
// Do this synchronously to conserve bandwidth
|
||||||
for _, node := range randomSubset {
|
for _, node := range gossipNodes {
|
||||||
correspondingPeer := s.manager.GetNode(meshId, node)
|
correspondingPeer := s.manager.GetNode(meshId, node)
|
||||||
|
|
||||||
if correspondingPeer == nil {
|
if correspondingPeer == nil {
|
||||||
logging.Log.WriteErrorf("node %s does not exist", node)
|
logging.Log.WriteErrorf("node %s does not exist", node)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
|
err := s.requester.SyncMesh(meshId, correspondingPeer)
|
||||||
|
|
||||||
if err == nil || err == io.EOF {
|
if err == nil || err == io.EOF {
|
||||||
succeeded = true
|
succeeded = true
|
||||||
@ -96,6 +105,15 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.manager.GetMesh(meshId).SaveChanges()
|
s.manager.GetMesh(meshId).SaveChanges()
|
||||||
|
s.lastSync = uint64(time.Now().Unix())
|
||||||
|
|
||||||
|
logging.Log.WriteInfof("UPDATING WG CONF")
|
||||||
|
err = s.manager.ApplyConfig()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,31 +17,20 @@ type SyncErrorHandlerImpl struct {
|
|||||||
meshManager mesh.MeshManager
|
meshManager mesh.MeshManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
|
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
|
||||||
mesh := s.meshManager.GetMesh(meshId)
|
mesh := s.meshManager.GetMesh(meshId)
|
||||||
|
mesh.Mark(nodeId)
|
||||||
if mesh == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// self, err := s.meshManager.GetSelf(meshId)
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
|
|
||||||
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
|
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
|
||||||
errStatus, _ := status.FromError(err)
|
errStatus, _ := status.FromError(err)
|
||||||
|
|
||||||
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
|
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
|
||||||
|
|
||||||
switch errStatus.Code() {
|
switch errStatus.Code() {
|
||||||
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
||||||
return s.incrementFailedCount(meshId, endpoint)
|
return s.handleFailed(meshId, nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
// SyncRequester: coordinates the syncing of meshes
|
// SyncRequester: coordinates the syncing of meshes
|
||||||
type SyncRequester interface {
|
type SyncRequester interface {
|
||||||
GetMesh(meshId string, ifName string, port int, endPoint string) error
|
GetMesh(meshId string, ifName string, port int, endPoint string) error
|
||||||
SyncMesh(meshid string, endPoint string) error
|
SyncMesh(meshid string, meshNode mesh.MeshNode) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncRequesterImpl struct {
|
type SyncRequesterImpl struct {
|
||||||
@ -56,8 +56,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
|
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
|
||||||
ok := s.errorHdlr.Handle(meshId, endpoint, err)
|
ok := s.errorHdlr.Handle(meshId, pubKey, err)
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
return nil
|
return nil
|
||||||
@ -67,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SyncMesh: Proactively send a sync request to the other mesh
|
// SyncMesh: Proactively send a sync request to the other mesh
|
||||||
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
|
||||||
|
endpoint := meshNode.GetHostEndpoint()
|
||||||
|
pubKey, _ := meshNode.GetPublicKey()
|
||||||
|
|
||||||
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
|
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -96,7 +99,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
err = s.syncMesh(mesh, ctx, c)
|
err = s.syncMesh(mesh, ctx, c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.handleErr(meshId, endpoint, err)
|
return s.handleErr(meshId, pubKey.String(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
||||||
|
@ -3,7 +3,6 @@ package wg
|
|||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
@ -35,8 +34,7 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.
|
|||||||
}
|
}
|
||||||
|
|
||||||
md5 := crypto.MD5.New().Sum(randomBuf)
|
md5 := crypto.MD5.New().Sum(randomBuf)
|
||||||
|
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
|
||||||
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
|
|
||||||
|
|
||||||
err = rtnl.CreateLink(md5Str)
|
err = rtnl.CreateLink(md5Str)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user