Compare commits


16 Commits

Author SHA1 Message Date
815c4484ee 47-default-routing
Implemented default routing and reduced the size of gossip messages. Using a
64-bit hash function to identify vectors.
2023-12-08 20:02:57 +00:00
0058c9f4c9 47-default-routing
Implementing default routing so that all traffic goes out through an
exit point.
2023-12-08 11:49:24 +00:00
92c0805275 Merge pull request #46 from tim-beatham/45-use-statistical-testing
45 use statistical testing
2023-12-07 18:20:25 +00:00
661fb0d54c 45-use-statistical-testing
Keepalive is per mesh and not per node.
Using a total ordering mechanism similar to Paxos to elect a leader:
if the leader doesn't update its timestamp within 3 * keepAlive then
give the leader a gravestone and elect the next leader.
The leader is chosen by lexicographically ordered public key.
2023-12-07 18:18:13 +00:00
64885f1055 45-use-statistical-testing
Using statistical testing to determine whether a node has failed.
2023-12-07 01:44:54 +00:00
2169f7796f Merge pull request #44 from tim-beatham/43-gravestones
43-use-gravestones
2023-12-06 22:46:05 +00:00
a3ceff019d 43-use-gravestones
Change of approach from keepalive to a noiseless protocol
2023-12-06 22:45:04 +00:00
b78d96986c Merge pull request #42 from tim-beatham/41-bugfix-fluctuating-ips
41 bugfix fluctuating ips
2023-12-06 14:37:14 +00:00
1b18d89c9f 41-bugfix-fluctuating-ips
Fluctuating IPs were creating a hub and spoke topology.
2023-12-05 02:00:16 +00:00
245a2c5f58 41-bugfix-fluctuating-ips
If the node is a peer then add the client to the WG
configuration.
2023-12-04 17:40:24 +00:00
c40f7510b8 41-bugfix-fluctuating-ips
IPs of clients were fluctuating because there isn't a strict order on
clients. Clients need to be processed before the peers.
2023-12-04 17:32:50 +00:00
78d748770c BUGFIX: Hash client by public key 2023-12-04 17:13:51 +00:00
0ff2a8eef9 BUGFIX: Allowed IPs fluctuating 2023-12-04 17:11:37 +00:00
fd7bd80485 BUGFIX
Don't get the device each time; it is an expensive operation.
2023-12-04 16:40:15 +00:00
3ef1b68ba5 BUGFIX: Hashing datastore to work out changes
Changed hashing implementation to work out if there are changes
in the data store
2023-11-30 15:58:26 +00:00
b9ba836ae3 Merge pull request #40 from tim-beatham/39-implement-two-phase-map
39-implement-two-phase-map
2023-11-30 02:03:36 +00:00
23 changed files with 805 additions and 285 deletions


@ -63,8 +63,7 @@ func main() {
syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
keepAlive := timer.NewTimestampScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -82,13 +81,12 @@ func main() {
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
go keepAlive.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")
syncScheduler.Stop()
timestampScheduler.Stop()
keepAlive.Stop()
ctrlServer.Close()
client.Close()
}


@ -477,54 +477,54 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodes, err := m.doc.Path("nodes").Get()
func (m *CrdtMeshManager) Prune() error {
// nodes, err := m.doc.Path("nodes").Get()
if err != nil {
return err
}
// if err != nil {
// return err
// }
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
// if nodes.Kind() != automerge.KindMap {
// return errors.New("node must be a map")
// }
values, err := nodes.Map().Values()
// values, err := nodes.Map().Values()
if err != nil {
return err
}
// if err != nil {
// return err
// }
deletionNodes := make([]string, 0)
// deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
// for nodeId, node := range values {
// if node.Kind() != automerge.KindMap {
// return errors.New("node must be a map")
// }
nodeMap := node.Map()
// nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
// timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
// if err != nil {
// return err
// }
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
// if timeStamp.Kind() != automerge.KindInt64 {
// return errors.New("timestamp is not int64")
// }
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
// timeValue := timeStamp.Int64()
// nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
// if nowValue-timeValue >= int64(pruneTime) {
// deletionNodes = append(deletionNodes, nodeId)
// }
// }
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
// for _, node := range deletionNodes {
// logging.Log.WriteInfof("Pruning %s", node)
// nodes.Map().Delete(node)
// }
return nil
}


@ -47,6 +47,8 @@ type WgMeshConfiguration struct {
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// AdvertiseDefaultRoute advertises a default route out of the mesh.
AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"`
// Endpoint is the IP at which this computer is publicly reachable;
// the use case is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
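For reference, a configuration fragment using the field names declared in the yaml tags above; the values are illustrative only (203.0.113.10 is a documentation address, not taken from this changeset):

advertiseRoutes: true      # advertise routes to the other meshes this node is in
advertiseDefaults: true    # advertise a default route (::/0) out of this mesh
publicEndpoint: 203.0.113.10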


@ -5,6 +5,7 @@ import (
"encoding/gob"
"fmt"
"net"
"slices"
"strings"
"time"
@ -48,6 +49,13 @@ type MeshNode struct {
Description string
Services map[string]string
Type string
Tombstone bool
}
// Mark: marks the node as unreachable. This is not broadcast on
// synchronisation
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
m.store.Mark(nodeId)
}
// GetHostEndpoint: gets the gRPC endpoint of the node
@ -171,8 +179,16 @@ func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
nodes := m.store.AsList()
snapshot := make(map[string]MeshNode)
for _, node := range nodes {
snapshot[node.PublicKey] = node
}
return &MeshSnapshot{
Nodes: m.store.AsMap(),
Nodes: snapshot,
}, nil
}
@ -200,11 +216,11 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
// Load() loads a mesh network
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
buf := bytes.NewBuffer(bs)
dec := gob.NewDecoder(buf)
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
err := dec.Decode(&snapshot)
m.store.Merge(snapshot)
return err
}
@ -222,13 +238,13 @@ func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
// HasChanges returns true if we have changes since last time we synced
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
clockValue := m.store.GetClock()
clockValue := m.store.GetHash()
return clockValue != m.LastClock
}
// Record that we have changes and save the corresponding changes
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
clockValue := m.store.GetClock()
clockValue := m.store.GetHash()
m.LastClock = clockValue
}
@ -238,6 +254,31 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
// Sort nodes by their public key
peers := m.GetPeers()
slices.Sort(peers)
if len(peers) == 0 {
return nil
}
peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) {
m.store.Mark(peerToUpdate)
if len(peers) < 2 {
return nil
}
peerToUpdate = peers[1]
}
if peerToUpdate != nodeId {
return nil
}
// Refresh causes the node to update its timestamp
node := m.store.Get(nodeId)
node.Timestamp = time.Now().Unix()
m.store.Put(nodeId, node)
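A minimal sketch of the election rule in UpdateTimeStamp above, with simplified types (the function name, peer slice, and last-seen map are hypothetical): peers are totally ordered by public key, the first peer leads, and a leader silent for more than 3 * keepAlive is skipped in favour of the next one.

package sketch

import "sort"

// electTimestampLeader returns the peer responsible for refreshing
// the mesh timestamp, mirroring the logic in UpdateTimeStamp.
func electTimestampLeader(peers []string, lastSeen map[string]uint64, keepAlive, now uint64) string {
    sort.Strings(peers) // total order on public keys
    if len(peers) == 0 {
        return ""
    }
    leader := peers[0]
    // Skip a leader that has been silent for more than 3 * keepAlive.
    if now-lastSeen[leader] > 3*keepAlive && len(peers) > 1 {
        leader = peers[1]
    }
    return leader
}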
@ -256,14 +297,25 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
prevRoute, ok := node.Routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
changes = true
node.Routes[route.GetDestination().String()] = Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
}
}
}
if changes {
m.store.Put(nodeId, node)
}
return nil
}
@ -357,20 +409,27 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
}
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune()
return nil
}
// GetPeers: get a list of contactable peers
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
nodes := lib.MapValues(m.store.AsMap())
nodes := m.store.AsList()
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
if mn.Type != string(conf.PEER_ROLE) {
return false
}
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
// If the node is marked as unreachable don't consider it a peer.
// This helps to optimise convergence time for unreachable nodes.
// However, advertising it to other nodes could result in flapping.
if m.store.IsMarked(mn.PublicKey) {
return false
}
return true
})
return lib.Map(nodes, func(mn MeshNode) string {


@ -2,6 +2,7 @@ package crdt
import (
"fmt"
"hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
@ -16,7 +17,11 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
IfName: params.DevName,
Client: params.Client,
conf: params.Conf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID),
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}, uint64(3*params.Conf.KeepAliveTime)),
}, nil
}


@ -2,27 +2,29 @@
package crdt
import (
"cmp"
"sync"
)
type Bucket[D any] struct {
Vector uint64
Contents D
Gravestone bool
}
// GMap is a set that can only grow in size
type GMap[K comparable, D any] struct {
type GMap[K cmp.Ordered, D any] struct {
lock sync.RWMutex
contents map[K]Bucket[D]
getClock func() uint64
contents map[uint64]Bucket[D]
clock *VectorClock[K]
}
func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock()
clock := g.getClock() + 1
clock := g.clock.IncrementClock()
g.contents[key] = Bucket[D]{
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
Vector: clock,
Contents: value,
}
@ -31,6 +33,10 @@ func (g *GMap[K, D]) Put(key K, value D) {
}
func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key))
}
func (g *GMap[K, D]) contains(key uint64) bool {
g.lock.RLock()
_, ok := g.contents[key]
@ -40,7 +46,7 @@ func (g *GMap[K, D]) Contains(key K) bool {
return ok
}
func (g *GMap[K, D]) put(key K, b Bucket[D]) {
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
g.lock.Lock()
if g.contents[key].Vector < b.Vector {
@ -50,7 +56,7 @@ func (g *GMap[K, D]) put(key K, b Bucket[D]) {
g.lock.Unlock()
}
func (g *GMap[K, D]) get(key K) Bucket[D] {
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
g.lock.RLock()
bucket := g.contents[key]
g.lock.RUnlock()
@ -59,13 +65,38 @@ func (g *GMap[K, D]) get(key K) Bucket[D] {
}
func (g *GMap[K, D]) Get(key K) D {
return g.get(key).Contents
return g.get(g.clock.hashFunc(key)).Contents
}
func (g *GMap[K, D]) Keys() []K {
func (g *GMap[K, D]) Mark(key K) {
g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true
g.contents[g.clock.hashFunc(key)] = bucket
g.lock.Unlock()
}
// IsMarked: returns true if the node is marked
func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false
g.lock.RLock()
contents := make([]K, len(g.contents))
bucket, ok := g.contents[g.clock.hashFunc(key)]
if ok {
marked = bucket.Gravestone
}
g.lock.RUnlock()
return marked
}
func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock()
contents := make([]uint64, len(g.contents))
index := 0
for key := range g.contents {
@ -77,8 +108,8 @@ func (g *GMap[K, D]) Keys() []K {
return contents
}
func (g *GMap[K, D]) Save() map[K]Bucket[D] {
buckets := make(map[K]Bucket[D])
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for key, value := range g.contents {
@ -89,8 +120,8 @@ func (g *GMap[K, D]) Save() map[K]Bucket[D] {
return buckets
}
func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
buckets := make(map[K]Bucket[D])
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D])
g.lock.RLock()
for _, key := range keys {
@ -101,8 +132,8 @@ func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
return buckets
}
func (g *GMap[K, D]) GetClock() map[K]uint64 {
clock := make(map[K]uint64)
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
g.lock.RLock()
for key, bucket := range g.contents {
@ -113,9 +144,33 @@ func (g *GMap[K, D]) GetClock() map[K]uint64 {
return clock
}
func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] {
func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0)
g.lock.RLock()
for _, value := range g.contents {
hash += value.Vector
}
g.lock.RUnlock()
return hash
}
func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale()
g.lock.Lock()
for _, outlier := range stale {
delete(g.contents, outlier)
}
g.lock.Unlock()
}
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
return &GMap[K, D]{
contents: make(map[K]Bucket[D]),
getClock: getClock,
contents: make(map[uint64]Bucket[D]),
clock: clock,
}
}
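The guard in GMap.put above is what makes the map a convergent CRDT: an incoming bucket only replaces the local one when it carries a newer vector-clock value, so merges are commutative and idempotent. A toy restatement of that rule, with simplified types:

package sketch

type bucket struct {
    vector   uint64
    contents string
}

// mergeBucket keeps whichever bucket has the newer vector value,
// matching the comparison in GMap.put.
func mergeBucket(local map[uint64]bucket, key uint64, remote bucket) {
    if local[key].vector < remote.vector {
        local[key] = remote
    }
}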


@ -1,33 +1,37 @@
package crdt
import (
"sync"
"cmp"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type TwoPhaseMap[K comparable, D any] struct {
type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D]
removeMap *GMap[K, bool]
vectors map[K]uint64
Clock *VectorClock[K]
processId K
lock sync.RWMutex
}
type TwoPhaseMapSnapshot[K comparable, D any] struct {
Add map[K]Bucket[D]
Remove map[K]Bucket[bool]
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
Add map[uint64]Bucket[D]
Remove map[uint64]Bucket[bool]
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
if !m.addMap.Contains(key) {
return m.contains(m.Clock.hashFunc(key))
}
// Contains checks whether the value exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) {
return false
}
addValue := m.addMap.get(key)
if !m.removeMap.Contains(key) {
if !m.removeMap.contains(key) {
return true
}
@ -46,32 +50,39 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
return m.addMap.Get(key)
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.incrementClock()
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
var result D
m.lock.Lock()
if _, ok := m.vectors[key]; !ok {
m.vectors[key] = msgSequence
if !m.contains(key) {
return result
}
m.lock.Unlock()
return m.addMap.get(key).Contents
}
// Put places the key K in the map
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data)
}
func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key)
}
// Remove removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true)
}
func (m *TwoPhaseMap[K, D]) Keys() []K {
keys := make([]K, 0)
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
keys := make([]uint64, 0)
addKeys := m.addMap.Keys()
for _, key := range addKeys {
if !m.Contains(key) {
if !m.contains(key) {
continue
}
@ -81,16 +92,16 @@ func (m *TwoPhaseMap[K, D]) Keys() []K {
return keys
}
func (m *TwoPhaseMap[K, D]) AsMap() map[K]D {
theMap := make(map[K]D)
func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0)
keys := m.Keys()
keys := m.keys()
for _, key := range keys {
theMap[key] = m.Get(key)
theList = append(theList, m.get(key))
}
return theMap
return theList
}
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
@ -110,34 +121,21 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
}
}
type TwoPhaseMapState[K comparable] struct {
AddContents map[K]uint64
RemoveContents map[K]uint64
type TwoPhaseMapState[K cmp.Ordered] struct {
Vectors map[uint64]uint64
AddContents map[uint64]uint64
RemoveContents map[uint64]uint64
}
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value)
}
m.vectors[m.processId] = maxClock + 1
m.lock.Unlock()
return maxClock
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key)
}
func (m *TwoPhaseMap[K, D]) GetClock() uint64 {
maxClock := uint64(0)
m.lock.RLock()
for _, value := range m.vectors {
maxClock = max(maxClock, value)
}
m.lock.RUnlock()
return maxClock
// GetHash: get the hash of the current state of the map.
// Sums the current values of the vectors; since clock values only
// increase, this gives a good approximation of whether state changed.
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
}
// GetState: get the current vector clock of the add and remove
@ -147,6 +145,7 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
removeContents := m.removeMap.GetClock()
return &TwoPhaseMapState[K]{
Vectors: m.Clock.GetClock(),
AddContents: addContents,
RemoveContents: removeContents,
}
@ -154,8 +153,8 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{
AddContents: make(map[K]uint64),
RemoveContents: make(map[K]uint64),
AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64),
}
for key, value := range state.AddContents {
@ -166,7 +165,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
}
}
for key, value := range state.AddContents {
for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key]
if !ok || otherValue < value {
@ -178,31 +177,35 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
}
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
m.lock.Lock()
for key, value := range snapshot.Add {
// The gravestone is local to this node only.
// Discover for ourselves whether the node is alive.
m.addMap.put(key, value)
m.vectors[key] = max(value.Vector, m.vectors[key])
m.Clock.put(key, value.Vector)
}
for key, value := range snapshot.Remove {
m.removeMap.put(key, value)
m.vectors[key] = max(value.Vector, m.vectors[key])
m.Clock.put(key, value.Vector)
}
}
m.lock.Unlock()
func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune()
m.removeMap.Prune()
m.Clock.Prune()
}
// NewTwoPhaseMap: create a new two-phase map. Consists of two maps:
// a grow map and a remove map. If both timestamps are equal then favour
// keeping it in the map
func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] {
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
m := TwoPhaseMap[K, D]{
vectors: make(map[K]uint64),
processId: processId,
Clock: NewVectorClock(processId, hashKey, staleTime),
}
m.addMap = NewGMap[K, D](m.incrementClock)
m.removeMap = NewGMap[K, bool](m.incrementClock)
m.addMap = NewGMap[K, D](m.Clock)
m.removeMap = NewGMap[K, bool](m.Clock)
return &m
}
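GetHash above combines the two sides as (add + 1) * (remove + 1); the + 1 terms stop an empty side from zeroing the product, so a change to either the add-map or the remove-map always changes the combined value. A quick check of that property:

package main

import "fmt"

// combinedHash reproduces the expression used by TwoPhaseMap.GetHash.
func combinedHash(addHash, removeHash uint64) uint64 {
    return (addHash + 1) * (removeHash + 1)
}

func main() {
    fmt.Println(combinedHash(0, 0)) // 1: nothing stored yet
    fmt.Println(combinedHash(5, 0)) // 6: additions only
    fmt.Println(combinedHash(5, 3)) // 24: additions and removals
}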


@ -10,7 +10,8 @@ import (
type SyncState int
const (
PREPARE SyncState = iota
HASH SyncState = iota
PREPARE
PRESENT
EXCHANGE
MERGE
@ -26,16 +27,54 @@ type TwoPhaseSyncer struct {
peerMsg []byte
}
type TwoPhaseHash struct {
Hash uint64
}
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
hash := TwoPhaseHash{
Hash: syncer.manager.store.Clock.GetHash(),
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(*syncer.mapState)
err := enc.Encode(hash)
if err != nil {
logging.Log.WriteInfof(err.Error())
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
return buffer.Bytes(), true
}
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
dec := gob.NewDecoder(recvBuffer)
var hash TwoPhaseHash
err := dec.Decode(&hash)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
// If the vector clocks are equal then there is no need to merge state.
// This helps to reduce bandwidth by detecting it early
if hash.Hash == syncer.manager.store.Clock.GetHash() {
return nil, false
}
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err = enc.Encode(*syncer.mapState)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
syncer.IncrementState()
@ -54,10 +93,11 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
err := dec.Decode(&mapState)
if err != nil {
logging.Log.WriteInfof(err.Error())
logging.Log.WriteErrorf(err.Error())
}
difference := syncer.mapState.Difference(&mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer
enc := gob.NewEncoder(&sendBuffer)
@ -100,7 +140,6 @@ func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
dec.Decode(&snapshot)
syncer.manager.store.Merge(snapshot)
return nil, false
}
@ -125,11 +164,14 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED")
t.manager.SaveChanges()
if t.state >= MERGE {
t.manager.store.Clock.IncrementClock()
}
}
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
var generateMessageFsm SyncFSM = SyncFSM{
HASH: hash,
PREPARE: prepare,
PRESENT: present,
EXCHANGE: exchange,
@ -138,7 +180,7 @@ func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
return &TwoPhaseSyncer{
manager: manager,
state: PREPARE,
state: HASH,
mapState: manager.store.GenerateMessage(),
generateMessageFSM: generateMessageFsm,
}
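The new HASH state sends only the 64-bit state hash before any map state is exchanged; the peer's PREPARE step decodes it and aborts the sync when the hashes already match. A self-contained sketch of that early-out using the same gob wire format (the function names are illustrative):

package main

import (
    "bytes"
    "encoding/gob"
    "fmt"
)

type TwoPhaseHash struct {
    Hash uint64
}

// encodeHash gob-encodes the local state hash, as the hash state does.
func encodeHash(h TwoPhaseHash) ([]byte, error) {
    var buf bytes.Buffer
    err := gob.NewEncoder(&buf).Encode(h)
    return buf.Bytes(), err
}

// statesEqual decodes the peer's hash and compares it with our own.
func statesEqual(localHash uint64, msg []byte) (bool, error) {
    var h TwoPhaseHash
    if err := gob.NewDecoder(bytes.NewBuffer(msg)).Decode(&h); err != nil {
        return false, err
    }
    return h.Hash == localHash, nil
}

func main() {
    msg, _ := encodeHash(TwoPhaseHash{Hash: 42})
    same, _ := statesEqual(42, msg)
    fmt.Println(same) // true: the peers would skip the full exchange
}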

pkg/crdt/vector_clock.go (new file)

@ -0,0 +1,149 @@
package crdt
import (
"cmp"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type VectorBucket struct {
// clock: the current value of the node's clock
clock uint64
// lastUpdate: the last update time we have seen for this node
lastUpdate uint64
}
// VectorClock defines an abstract data type
// for a vector clock implementation
type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket
lock sync.RWMutex
processID K
staleTime uint64
hashFunc func(K) uint64
}
// IncrementClock: increments the node's value in the vector clock
func (m *VectorClock[K]) IncrementClock() uint64 {
maxClock := uint64(0)
m.lock.Lock()
for _, value := range m.vectors {
maxClock = max(maxClock, value.clock)
}
newBucket := VectorBucket{
clock: maxClock + 1,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[m.hashFunc(m.processID)] = &newBucket
m.lock.Unlock()
return maxClock
}
// GetHash: gets the hash of the vector clock used to determine if there
// are any changes
func (m *VectorClock[K]) GetHash() uint64 {
m.lock.RLock()
hash := uint64(0)
for key, bucket := range m.vectors {
hash += key * (bucket.clock + 1)
}
m.lock.RUnlock()
return hash
}
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors {
m.put(key, value)
}
}
// getStale: get all entries that are stale within the mesh
func (m *VectorClock[K]) getStale() []uint64 {
m.lock.RLock()
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
return max(i, vb.lastUpdate)
})
toRemove := make([]uint64, 0)
for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key)
}
}
m.lock.RUnlock()
return toRemove
}
func (m *VectorClock[K]) Prune() {
stale := m.getStale()
m.lock.Lock()
for _, key := range stale {
delete(m.vectors, key)
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
return m.vectors[m.hashFunc(processId)].lastUpdate
}
func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value)
}
func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue := uint64(0)
m.lock.Lock()
bucket, ok := m.vectors[key]
if ok {
clockValue = bucket.clock
}
if value > clockValue {
newBucket := VectorBucket{
clock: value,
lastUpdate: uint64(time.Now().Unix()),
}
m.vectors[key] = &newBucket
}
m.lock.Unlock()
}
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64)
m.lock.RLock()
for key, value := range m.vectors {
clock[key] = value.clock
}
m.lock.RUnlock()
return clock
}
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
return &VectorClock[K]{
vectors: make(map[uint64]*VectorBucket),
processID: processID,
staleTime: staleTime,
hashFunc: hashFunc,
}
}
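IncrementClock above is Lamport-style: the value stamped into the node's own bucket is one greater than the largest clock value observed from any process, which keeps clocks comparable across the mesh. The core rule in isolation (package and function name are hypothetical):

package sketch

// nextClock computes the next clock value from the observed vectors,
// as IncrementClock does before stamping the node's own bucket.
func nextClock(vectors map[uint64]uint64) uint64 {
    maxClock := uint64(0)
    for _, v := range vectors {
        maxClock = max(maxClock, v)
    }
    return maxClock + 1
}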


@ -1,11 +1,13 @@
package lib
import "cmp"
// MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V {
func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{})
}
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V {
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude))
i := 0
@ -26,7 +28,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
return values
}
func MapKeys[K comparable, V any](m map[K]V) []K {
func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
values := make([]K, len(m))
i := 0
@ -76,3 +78,13 @@ func Contains[V any](list []V, proposition func(V) bool) bool {
return false
}
func Reduce[A any, V any](start A, values []V, reduce func(A, V) A) A {
accum := start
for _, elem := range values {
accum = reduce(accum, elem)
}
return accum
}


@ -248,6 +248,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
if route.equal(r) {
return false
}
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
return false
}
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
return false
}
}
return true
}
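For reference, how the standard library conversions used above behave (a property of net.IP, not of this patch): To4 is nil for IPv6 addresses, but To16 is non-nil for both families because every IPv4 address has an IPv4-in-IPv6 form, so it does not by itself identify IPv6 destinations.

package main

import (
    "fmt"
    "net"
)

func main() {
    fmt.Println(net.ParseIP("192.0.2.1").To4() != nil)  // true
    fmt.Println(net.ParseIP("fd00::1").To4() != nil)    // false
    fmt.Println(net.ParseIP("192.0.2.1").To16() != nil) // true
    fmt.Println(net.ParseIP("fd00::1").To16() != nil)   // true
}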
@ -255,7 +263,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
err := c.DeleteRoute(ifName, route)
if err != nil {

pkg/lib/stats.go (new file)

@ -0,0 +1,40 @@
// lib contains helper functions for the implementation
package lib
import (
"cmp"
"math"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
// Modelling the counts as a normal distribution, get the keys whose
// values are outliers below the lower confidence bound
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
n := float64(len(counts))
keys := MapKeys(counts)
values := make([]float64, len(keys))
for index, key := range keys {
values[index] = float64(counts[key])
}
mean := stat.Mean(values, nil)
stdDev := stat.StdDev(values, nil)
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
lowerBound := mean - moe
var outliers []K
for i, count := range values {
if count < lowerBound {
outliers = append(outliers, keys[i])
}
}
return outliers
}
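GetOutliers builds a lower confidence bound: with margin of error z(1 - alpha/2) * stddev / sqrt(n), any key whose count falls below mean - moe is reported. A usage sketch with hypothetical counts, where nodeD lags the others and would be flagged at alpha = 0.05:

package main

import (
    "fmt"

    "github.com/tim-beatham/wgmesh/pkg/lib"
)

func main() {
    counts := map[string]uint64{
        "nodeA": 100,
        "nodeB": 98,
        "nodeC": 97,
        "nodeD": 42,
    }
    fmt.Println(lib.GetOutliers(counts, 0.05)) // expected: [nodeD]
}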


@ -4,6 +4,7 @@ import (
"fmt"
"net"
"slices"
"strings"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
@ -52,7 +53,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
clients, ok := peerToClients[node.GetWgHost().String()]
clients, ok := peerToClients[pubKey.String()]
if ok {
allowedips = append(allowedips, clients...)
@ -115,7 +116,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
@ -124,6 +124,13 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
v6Default, _, _ := net.ParseCIDR("::/0")
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute {
return true
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
@ -154,6 +161,142 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
return routes
}
// getCorrespondingPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
hashFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
peer := lib.ConsistentHash(peers, client, hashFunc, hashFunc)
return peer
}
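getCorrespondingPeer assigns every client to a peer by consistent hashing on public keys, so each client deterministically follows one peer and assignments move minimally when peers join or leave. A generic ring sketch of that idea (an assumption about the technique, not the implementation of lib.ConsistentHash):

package sketch

import "sort"

// assignPeer maps a client to the first peer at or after the client's
// position on the hash ring, wrapping around at the end.
func assignPeer(peers []string, client string, hash func(string) int) string {
    if len(peers) == 0 {
        return ""
    }
    sort.Slice(peers, func(i, j int) bool {
        return hash(peers[i]) < hash(peers[j])
    })
    h := hash(client)
    for _, p := range peers {
        if hash(p) >= h {
            return p
        }
    }
    return peers[0] // wrap around the ring
}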
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(mesh.GetMeshId())
routes := lib.Map(lib.MapKeys(m.getRoutes(mesh)), func(destination string) net.IPNet {
_, ipNet, _ := net.ParseCIDR(destination)
return *ipNet
})
routes = append(routes, *meshNet)
if err != nil {
return nil, err
}
peer := m.getCorrespondingPeer(peers, self)
pubKey, _ := peer.GetPublicKey()
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil {
return nil, err
}
peerCfgs := make([]wgtypes.PeerConfig, 1)
peerCfgs[0] = wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
PersistentKeepaliveInterval: &keepAlive,
AllowedIPs: routes,
}
installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: route,
})
}
cfg := wgtypes.Config{
Peers: peerCfgs,
}
m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(mesh)
installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return nil, err
}
for _, n := range clients {
if len(peers) > 0 {
peer := m.getCorrespondingPeer(peers, n)
pubKey, _ := peer.GetPublicKey()
clients, ok := peerToClients[pubKey.String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[pubKey.String()] = clients
}
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil {
return nil, err
}
peerConfigs = append(peerConfigs, *cfg)
}
}
}
for _, n := range peers {
if NodeEquals(n, self) {
continue
}
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil {
return nil, err
}
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0")
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs = append(peerConfigs, *peer)
}
cfg := wgtypes.Config{
Peers: peerConfigs,
ReplacePeers: true,
}
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
@ -162,13 +305,19 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
}
nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
dev, _ := mesh.GetDevice()
slices.SortFunc(nodes, func(a, b MeshNode) int {
return strings.Compare(string(a.GetType()), string(b.GetType()))
})
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0
clients := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.CLIENT_ROLE
})
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
@ -176,72 +325,26 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(mesh)
installedRoutes := make([]lib.Route, 0)
var cfg *wgtypes.Config = nil
for _, n := range nodes {
if NodeEquals(n, self) {
continue
switch self.GetType() {
case conf.PEER_ROLE:
cfg, err = m.getPeerConfig(mesh, peers, clients, dev)
case conf.CLIENT_ROLE:
cfg, err = m.getClientConfig(mesh, peers, clients, dev)
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
hashFunc := func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
}
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
dev, _ := mesh.GetDevice()
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil {
return err
}
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs[count] = *peer
count++
}
cfg := wgtypes.Config{
Peers: peerConfigs,
}
dev, err := mesh.GetDevice()
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
return err
}
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
if err != nil {
return err
}
return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
return nil
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
@ -271,6 +374,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true,
})
return nil


@ -8,7 +8,6 @@ import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -112,7 +111,7 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
// Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime)
err := mesh.Prune()
if err != nil {
return err
@ -329,7 +328,6 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
@ -473,7 +471,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
m.RouteManager = params.RouteManager
if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m)
m.RouteManager = NewRouteManager(m, &params.Conf)
}
m.idGenerator = params.IdGenerator


@ -1,6 +1,9 @@
package mesh
import (
"net"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
@ -13,6 +16,7 @@ type RouteManager interface {
type RouteManagerImpl struct {
meshManager MeshManager
conf *conf.WgMeshConfiguration
}
func (r *RouteManagerImpl) UpdateRoutes() error {
@ -32,12 +36,23 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
routes, err := mesh1.GetRoutes(pubKey.String())
routeMap, err := mesh1.GetRoutes(pubKey.String())
if err != nil {
return err
}
if r.conf.AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
mesh1.AddRoutes(NodeID(self),
&RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: make([]string, 0),
})
}
for _, mesh2 := range meshes {
if mesh1 == mesh2 {
continue
@ -50,7 +65,9 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
routes := lib.MapValues(routeMap)
err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{
Destination: ipNet,
HopCount: 0,
Path: make([]string, 0),
@ -88,6 +105,6 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m}
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
return &RouteManagerImpl{meshManager: m, conf: conf}
}


@ -81,6 +81,11 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
// Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) {
panic("unimplemented")
}
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
@ -117,7 +122,7 @@ func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
return nil
}
// RemoveRoutes implements MeshProvider.
@ -126,7 +131,7 @@ func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error {
func (*MeshProviderStub) Prune() error {
return nil
}


@ -131,15 +131,18 @@ type MeshProvider interface {
AddService(nodeId, key, value string) error
// RemoveService: removes the service from the node. Throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
// Prune: prunes all nodes that have not updated their
// vector clock
Prune() error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
// Mark: marks the node as unreachable. This is not broadcast to the
// entire mesh and is not considered when syncing node state
Mark(nodeId string)
}
// HostParameters contains the IDs of a node


@ -19,7 +19,11 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
ip6Routes := lib.Filter(routes, func(r lib.Route) bool {
return r.Destination.IP.To4() == nil
})
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...)
if err != nil {
return err


@ -1,8 +1,8 @@
package sync
import (
"io"
"math/rand"
"sync"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
@ -12,7 +12,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
// Syncer: picks random nodes from the mesh
// Syncer: picks random nodes from the meshes
type Syncer interface {
Sync(meshId string) error
SyncMeshes() error
@ -25,70 +25,95 @@ type SyncerImpl struct {
syncCount int
cluster conn.ConnCluster
conf *conf.WgMeshConfiguration
lastSync uint64
}
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
s.manager.GetMesh(meshId).Prune()
if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
logging.Log.WriteInfof("UPDATING WG CONF")
before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
nodeNames := s.manager.GetMesh(meshId).GetPeers()
var gossipNodes []string
// A client always pings its peer for configuration
if self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
gossipNodes = make([]string, 1)
gossipNodes[0] = neighbour
} else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
logging.Log.WriteInfof("Random node: %s", node)
}
before := time.Now()
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
randomSubset = append(randomSubset, interCluster)
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
}
}
var waitGroup sync.WaitGroup
var succeeded bool = false
for index := range randomSubset {
waitGroup.Add(1)
go func(i int) error {
defer waitGroup.Done()
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
// Do this synchronously to conserve bandwidth
for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node)
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
logging.Log.WriteErrorf("node %s does not exist", node)
}
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
return err
}(index)
}
err := s.requester.SyncMesh(meshId, correspondingPeer)
waitGroup.Wait()
if err == nil || err == io.EOF {
succeeded = true
} else {
// If the synchronisation operation has failed then mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
}
}
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
if !succeeded {
// If we could not gossip with anyone then repeat.
s.infectionCount++
}
s.manager.GetMesh(meshId).SaveChanges()
s.lastSync = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err = s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil
}
@ -98,7 +123,7 @@ func (s *SyncerImpl) SyncMeshes() error {
err := s.Sync(meshId)
if err != nil {
return err
logging.Log.WriteErrorf(err.Error())
}
}
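The infectionCount update in Sync above is a decrement modulo InfectionCount: each gossip round consumes one round of "infection", the counter wraps instead of going negative, and a failed round increments it so the change keeps spreading. For InfectionCount = 3 the sequence from 0 is 2, 1, 0, 2, ...:

package main

import "fmt"

// nextInfection reproduces the modular decrement used by the syncer.
func nextInfection(infection, confCount int) int {
    return (confCount + infection - 1) % confCount
}

func main() {
    n := 0
    for i := 0; i < 4; i++ {
        n = nextInfection(n, 3)
        fmt.Print(n, " ") // 2 1 0 2
    }
}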


@ -17,31 +17,20 @@ type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
if mesh == nil {
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
mesh.Mark(nodeId)
return true
}
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
return s.incrementFailedCount(meshId, endpoint)
return s.handleFailed(meshId, nodeId)
}
return false


@ -15,7 +15,7 @@ import (
// SyncRequester: coordinates the syncing of meshes
type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
}
type SyncRequesterImpl struct {
@ -56,8 +56,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
ok := s.errorHdlr.Handle(meshId, endpoint, err)
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok {
return nil
@ -67,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
}
// SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
if err != nil {
@ -96,7 +99,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
return s.handleErr(meshId, pubKey.String(), err)
}
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)


@ -8,7 +8,8 @@ import (
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
syncer.SyncMeshes()
return nil
}
}


@ -3,7 +3,6 @@ package wg
import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/lib"
@ -35,8 +34,7 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
err = rtnl.CreateLink(md5Str)
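The interface-name change above swaps base64 for hex: %x only emits [0-9a-f], so the truncated name is always usable, whereas base64 can emit '+' and '/', and '/' in particular is not a valid character in a Linux interface name. A small illustration (the input bytes stand in for the random buffer above):

package main

import (
    "crypto/md5"
    "fmt"
)

func main() {
    sum := md5.Sum([]byte("example-random-bytes")) // stand-in for randomBuf
    fmt.Println(fmt.Sprintf("wg%x", sum)[:10])     // hex digest truncated to a short name
}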