mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-08-14 07:18:32 +02:00
Compare commits
10 Commits
39-impleme
...
43-gravest
Author | SHA1 | Date | |
---|---|---|---|
a3ceff019d | |||
b78d96986c | |||
1b18d89c9f | |||
245a2c5f58 | |||
c40f7510b8 | |||
78d748770c | |||
0ff2a8eef9 | |||
fd7bd80485 | |||
3ef1b68ba5 | |||
b9ba836ae3 |
@ -48,6 +48,13 @@ type MeshNode struct {
|
||||
Description string
|
||||
Services map[string]string
|
||||
Type string
|
||||
Tombstone bool
|
||||
}
|
||||
|
||||
// Mark: marks the node is unreachable. This is not broadcast on
|
||||
// syncrhonisation
|
||||
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
|
||||
m.store.Mark(nodeId)
|
||||
}
|
||||
|
||||
// GetHostEndpoint: gets the gRPC endpoint of the node
|
||||
@ -200,11 +207,11 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
|
||||
// Load() loads a mesh network
|
||||
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
|
||||
buf := bytes.NewBuffer(bs)
|
||||
|
||||
dec := gob.NewDecoder(buf)
|
||||
|
||||
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
|
||||
err := dec.Decode(&snapshot)
|
||||
|
||||
m.store.Merge(snapshot)
|
||||
return err
|
||||
}
|
||||
@ -222,13 +229,13 @@ func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
|
||||
|
||||
// HasChanges returns true if we have changes since last time we synced
|
||||
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
|
||||
clockValue := m.store.GetClock()
|
||||
clockValue := m.store.GetHash()
|
||||
return clockValue != m.LastClock
|
||||
}
|
||||
|
||||
// Record that we have changes and save the corresponding changes
|
||||
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
|
||||
clockValue := m.store.GetClock()
|
||||
clockValue := m.store.GetHash()
|
||||
m.LastClock = clockValue
|
||||
}
|
||||
|
||||
@ -256,14 +263,25 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
|
||||
|
||||
node := m.store.Get(nodeId)
|
||||
|
||||
changes := false
|
||||
|
||||
for _, route := range routes {
|
||||
node.Routes[route.GetDestination().String()] = Route{
|
||||
Destination: route.GetDestination().String(),
|
||||
Path: route.GetPath(),
|
||||
prevRoute, ok := node.Routes[route.GetDestination().String()]
|
||||
|
||||
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
|
||||
changes = true
|
||||
|
||||
node.Routes[route.GetDestination().String()] = Route{
|
||||
Destination: route.GetDestination().String(),
|
||||
Path: route.GetPath(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.store.Put(nodeId, node)
|
||||
if changes {
|
||||
m.store.Put(nodeId, node)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -357,8 +375,18 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
|
||||
}
|
||||
|
||||
// Prune: prunes all nodes that have not updated their timestamp in
|
||||
// pruneAmount seconds
|
||||
// pruneAmount of seconds
|
||||
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
|
||||
nodes := lib.MapValues(m.store.AsMap())
|
||||
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
||||
return time.Now().Unix()-mn.Timestamp > int64(pruneAmount)
|
||||
})
|
||||
|
||||
for _, node := range nodes {
|
||||
key, _ := node.GetPublicKey()
|
||||
m.store.Remove(key.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -370,6 +398,13 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the node is marked as unreachable don't consider it a peer.
|
||||
// This helps to optimize convergence time for unreachable nodes.
|
||||
// However advertising it to other nodes could result in flapping.
|
||||
if m.store.IsMarked(mn.PublicKey) {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
|
||||
})
|
||||
|
||||
|
@ -6,8 +6,9 @@ import (
|
||||
)
|
||||
|
||||
// Bucket is a single entry of a GMap: the stored value together with its
// bookkeeping metadata.
type Bucket[D any] struct {
	// Vector is the clock value associated with the entry.
	Vector uint64
	// Contents is the value stored under the key.
	Contents D
	// Gravestone flags the entry as marked. It is set by GMap.Mark and read
	// by GMap.IsMarked.
	Gravestone bool
}
|
||||
|
||||
// GMap is a set that can only grow in size
|
||||
@ -62,6 +63,30 @@ func (g *GMap[K, D]) Get(key K) D {
|
||||
return g.get(key).Contents
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Mark(key K) {
|
||||
g.lock.Lock()
|
||||
bucket := g.contents[key]
|
||||
bucket.Gravestone = true
|
||||
g.lock.Unlock()
|
||||
}
|
||||
|
||||
// IsMarked: returns true if the node is marked
|
||||
func (g *GMap[K, D]) IsMarked(key K) bool {
|
||||
marked := false
|
||||
|
||||
g.lock.RLock()
|
||||
|
||||
bucket, ok := g.contents[key]
|
||||
|
||||
if ok {
|
||||
marked = bucket.Gravestone
|
||||
}
|
||||
|
||||
g.lock.RUnlock()
|
||||
|
||||
return marked
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Keys() []K {
|
||||
g.lock.RLock()
|
||||
|
||||
|
@ -60,6 +60,10 @@ func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
||||
m.addMap.Put(key, data)
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) Mark(key K) {
|
||||
m.addMap.Mark(key)
|
||||
}
|
||||
|
||||
// Remove removes the value from the map
|
||||
func (m *TwoPhaseMap[K, D]) Remove(key K) {
|
||||
m.removeMap.Put(key, true)
|
||||
@ -115,6 +119,10 @@ type TwoPhaseMapState[K comparable] struct {
|
||||
RemoveContents map[K]uint64
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
||||
return m.addMap.IsMarked(key)
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
|
||||
maxClock := uint64(0)
|
||||
m.lock.Lock()
|
||||
@ -128,16 +136,19 @@ func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
|
||||
return maxClock
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) GetClock() uint64 {
|
||||
maxClock := uint64(0)
|
||||
// GetHash: Get the hash of the current state of the map
|
||||
// Sums the current values of the vectors. Provides good approximation
|
||||
// of increasing numbers
|
||||
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
||||
m.lock.RLock()
|
||||
|
||||
for _, value := range m.vectors {
|
||||
maxClock = max(maxClock, value)
|
||||
}
|
||||
sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 {
|
||||
return current + sum
|
||||
})
|
||||
|
||||
m.lock.RUnlock()
|
||||
return maxClock
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
// GetState: get the current vector clock of the add and remove
|
||||
@ -181,11 +192,15 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
||||
m.lock.Lock()
|
||||
|
||||
for key, value := range snapshot.Add {
|
||||
// Gravestone is local only to that node.
|
||||
// Discover ourselves if the node is alive
|
||||
value.Gravestone = false
|
||||
m.addMap.put(key, value)
|
||||
m.vectors[key] = max(value.Vector, m.vectors[key])
|
||||
}
|
||||
|
||||
for key, value := range snapshot.Remove {
|
||||
value.Gravestone = false
|
||||
m.removeMap.put(key, value)
|
||||
m.vectors[key] = max(value.Vector, m.vectors[key])
|
||||
}
|
||||
|
@ -125,7 +125,6 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
|
||||
|
||||
func (t *TwoPhaseSyncer) Complete() {
|
||||
logging.Log.WriteInfof("SYNC COMPLETED")
|
||||
t.manager.SaveChanges()
|
||||
}
|
||||
|
||||
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
||||
|
@ -76,3 +76,13 @@ func Contains[V any](list []V, proposition func(V) bool) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Reduce folds values into a single result. Starting from start, it applies
// reduce to the running result and each element in slice order and returns
// the final result. An empty or nil slice yields start unchanged.
func Reduce[A any, V any](start A, values []V, reduce func(A, V) A) A {
	result := start

	for i := 0; i < len(values); i++ {
		result = reduce(result, values[i])
	}

	return result
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
@ -52,7 +53,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
||||
allowedips := make([]net.IPNet, 1)
|
||||
allowedips[0] = *node.GetWgHost()
|
||||
|
||||
clients, ok := peerToClients[node.GetWgHost().String()]
|
||||
clients, ok := peerToClients[pubKey.String()]
|
||||
|
||||
if ok {
|
||||
allowedips = append(allowedips, clients...)
|
||||
@ -154,59 +155,100 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
return routes
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
snap, err := mesh.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
// getCorrespondignPeer: gets the peer corresponding to the client
|
||||
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
|
||||
hashFunc := func(mn MeshNode) int {
|
||||
pubKey, _ := mn.GetPublicKey()
|
||||
return lib.HashString(pubKey.String())
|
||||
}
|
||||
|
||||
nodes := lib.MapValues(snap.GetNodes())
|
||||
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
|
||||
|
||||
peers := lib.Filter(nodes, func(mn MeshNode) bool {
|
||||
return mn.GetType() == conf.PEER_ROLE
|
||||
})
|
||||
|
||||
var count int = 0
|
||||
peer := lib.ConsistentHash(peers, client, hashFunc, hashFunc)
|
||||
return peer
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) {
|
||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := m.getCorrespondingPeer(peers, self)
|
||||
|
||||
pubKey, _ := peer.GetPublicKey()
|
||||
|
||||
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
||||
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowedips := make([]net.IPNet, 1)
|
||||
_, ipnet, _ := net.ParseCIDR("::/0")
|
||||
allowedips[0] = *ipnet
|
||||
|
||||
peerCfgs := make([]wgtypes.PeerConfig, 1)
|
||||
|
||||
peerCfgs[0] = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Endpoint: endpoint,
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
AllowedIPs: allowedips,
|
||||
}
|
||||
|
||||
cfg := wgtypes.Config{
|
||||
Peers: peerCfgs,
|
||||
}
|
||||
|
||||
return &cfg, err
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
|
||||
peerToClients := make(map[string][]net.IPNet)
|
||||
routes := m.getRoutes(mesh)
|
||||
installedRoutes := make([]lib.Route, 0)
|
||||
peerConfigs := make([]wgtypes.PeerConfig, 0)
|
||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
||||
|
||||
for _, n := range nodes {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, n := range clients {
|
||||
if len(peers) > 0 {
|
||||
peer := m.getCorrespondingPeer(peers, n)
|
||||
pubKey, _ := peer.GetPublicKey()
|
||||
clients, ok := peerToClients[pubKey.String()]
|
||||
|
||||
if !ok {
|
||||
clients = make([]net.IPNet, 0)
|
||||
peerToClients[pubKey.String()] = clients
|
||||
}
|
||||
|
||||
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
||||
|
||||
if NodeEquals(self, peer) {
|
||||
cfg, err := m.convertMeshNode(n, dev, peerToClients, routes)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerConfigs = append(peerConfigs, *cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range peers {
|
||||
if NodeEquals(n, self) {
|
||||
continue
|
||||
}
|
||||
|
||||
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
|
||||
hashFunc := func(mn MeshNode) int {
|
||||
return lib.HashString(mn.GetWgHost().String())
|
||||
}
|
||||
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
|
||||
|
||||
clients, ok := peerToClients[peer.GetWgHost().String()]
|
||||
|
||||
if !ok {
|
||||
clients = make([]net.IPNet, 0)
|
||||
peerToClients[peer.GetWgHost().String()] = clients
|
||||
}
|
||||
|
||||
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
|
||||
continue
|
||||
}
|
||||
|
||||
dev, _ := mesh.GetDevice()
|
||||
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, route := range peer.AllowedIPs {
|
||||
@ -221,27 +263,66 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
}
|
||||
}
|
||||
|
||||
peerConfigs[count] = *peer
|
||||
count++
|
||||
peerConfigs = append(peerConfigs, *peer)
|
||||
}
|
||||
|
||||
cfg := wgtypes.Config{
|
||||
Peers: peerConfigs,
|
||||
Peers: peerConfigs,
|
||||
ReplacePeers: true,
|
||||
}
|
||||
|
||||
dev, err := mesh.GetDevice()
|
||||
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
||||
return &cfg, err
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
snap, err := mesh.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
|
||||
nodes := lib.MapValues(snap.GetNodes())
|
||||
dev, _ := mesh.GetDevice()
|
||||
|
||||
slices.SortFunc(nodes, func(a, b MeshNode) int {
|
||||
return strings.Compare(string(a.GetType()), string(b.GetType()))
|
||||
})
|
||||
|
||||
peers := lib.Filter(nodes, func(mn MeshNode) bool {
|
||||
return mn.GetType() == conf.PEER_ROLE
|
||||
})
|
||||
|
||||
clients := lib.Filter(nodes, func(mn MeshNode) bool {
|
||||
return mn.GetType() == conf.CLIENT_ROLE
|
||||
})
|
||||
|
||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
||||
var cfg *wgtypes.Config = nil
|
||||
|
||||
switch self.GetType() {
|
||||
case conf.PEER_ROLE:
|
||||
cfg, err = m.getPeerConfig(mesh, peers, clients, dev)
|
||||
case conf.CLIENT_ROLE:
|
||||
cfg, err = m.getClientConfig(mesh, peers, clients)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||
@ -270,7 +351,8 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
|
||||
}
|
||||
|
||||
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
|
||||
Peers: make([]wgtypes.PeerConfig, 0),
|
||||
Peers: make([]wgtypes.PeerConfig, 0),
|
||||
ReplacePeers: true,
|
||||
})
|
||||
|
||||
return nil
|
||||
|
@ -81,6 +81,11 @@ type MeshProviderStub struct {
|
||||
snapshot *MeshSnapshotStub
|
||||
}
|
||||
|
||||
// Mark implements MeshProvider.
|
||||
func (*MeshProviderStub) Mark(nodeId string) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// RemoveNode implements MeshProvider.
|
||||
func (*MeshProviderStub) RemoveNode(nodeId string) error {
|
||||
panic("unimplemented")
|
||||
@ -117,7 +122,7 @@ func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
|
||||
|
||||
// SetAlias implements MeshProvider.
|
||||
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
|
||||
panic("unimplemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutes implements MeshProvider.
|
||||
|
@ -140,6 +140,9 @@ type MeshProvider interface {
|
||||
GetRoutes(targetNode string) (map[string]Route, error)
|
||||
// RemoveNode(): remove the node from the mesh
|
||||
RemoveNode(nodeId string) error
|
||||
// Mark: marks the node as unreachable. The mark is not broadcast to the rest
|
||||
// of the mesh and is not considered when syncing node state
|
||||
Mark(nodeId string)
|
||||
}
|
||||
|
||||
// HostParameters contains the IDs of a node
|
||||
|
@ -1,8 +1,8 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"io"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
@ -59,36 +59,43 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
|
||||
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
|
||||
logging.Log.WriteInfof("Sending to random cluster")
|
||||
interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
||||
randomSubset = append(randomSubset, interCluster)
|
||||
randomSubset[len(randomSubset)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
||||
}
|
||||
|
||||
var waitGroup sync.WaitGroup
|
||||
var succeeded bool = false
|
||||
|
||||
for index := range randomSubset {
|
||||
waitGroup.Add(1)
|
||||
// Do this synchronously to conserve bandwidth
|
||||
for _, node := range randomSubset {
|
||||
correspondingPeer := s.manager.GetNode(meshId, node)
|
||||
|
||||
go func(i int) error {
|
||||
defer waitGroup.Done()
|
||||
if correspondingPeer == nil {
|
||||
logging.Log.WriteErrorf("node %s does not exist", node)
|
||||
}
|
||||
|
||||
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
|
||||
err = s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
|
||||
|
||||
if correspondingPeer == nil {
|
||||
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
|
||||
}
|
||||
|
||||
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
|
||||
return err
|
||||
}(index)
|
||||
if err == nil || err == io.EOF {
|
||||
succeeded = true
|
||||
} else {
|
||||
// If the synchronisation operation has failed then mark a gravestone
|
||||
// preventing the peer from being re-contacted until it has updated
|
||||
// itself
|
||||
s.manager.GetMesh(meshId).Mark(node)
|
||||
}
|
||||
}
|
||||
|
||||
waitGroup.Wait()
|
||||
|
||||
s.syncCount++
|
||||
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
|
||||
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
|
||||
|
||||
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
|
||||
|
||||
if !succeeded {
|
||||
// If could not gossip with anyone then repeat.
|
||||
s.infectionCount++
|
||||
}
|
||||
|
||||
s.manager.GetMesh(meshId).SaveChanges()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user