forked from extern/smegmesh
39-implement-two-phase-map
Implemented my own two phase map based on vector clocks
This commit is contained in:
parent
a82eab0686
commit
650901aba1
@ -1,4 +1,4 @@
|
|||||||
package crdt
|
package automerge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@ -6,7 +6,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/automerge/automerge-go"
|
"github.com/automerge/automerge-go"
|
||||||
@ -20,7 +19,6 @@ import (
|
|||||||
|
|
||||||
// CrdtMeshManager manages nodes in the crdt mesh
|
// CrdtMeshManager manages nodes in the crdt mesh
|
||||||
type CrdtMeshManager struct {
|
type CrdtMeshManager struct {
|
||||||
lock sync.RWMutex
|
|
||||||
MeshId string
|
MeshId string
|
||||||
IfName string
|
IfName string
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
@ -42,13 +40,10 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
|
|||||||
crdt.Services = make(map[string]string)
|
crdt.Services = make(map[string]string)
|
||||||
crdt.Timestamp = time.Now().Unix()
|
crdt.Timestamp = time.Now().Unix()
|
||||||
|
|
||||||
c.lock.Lock()
|
|
||||||
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
|
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
|
||||||
c.lock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
||||||
c.lock.RLock()
|
|
||||||
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil || node.Kind() != automerge.KindMap {
|
if err != nil || node.Kind() != automerge.KindMap {
|
||||||
@ -56,7 +51,6 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nodeType, err := node.Map().Get("type")
|
nodeType, err := node.Map().Get("type")
|
||||||
c.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil || nodeType.Kind() != automerge.KindStr {
|
if err != nil || nodeType.Kind() != automerge.KindStr {
|
||||||
return false
|
return false
|
||||||
@ -68,7 +62,6 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
|||||||
// isAlive: checks that the node's configuration has been updated
|
// isAlive: checks that the node's configuration has been updated
|
||||||
// since the rquired keep alive time
|
// since the rquired keep alive time
|
||||||
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
||||||
c.lock.RLock()
|
|
||||||
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil || node.Kind() != automerge.KindMap {
|
if err != nil || node.Kind() != automerge.KindMap {
|
||||||
@ -76,7 +69,6 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
timestamp, err := node.Map().Get("timestamp")
|
timestamp, err := node.Map().Get("timestamp")
|
||||||
c.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil || timestamp.Kind() != automerge.KindInt64 {
|
if err != nil || timestamp.Kind() != automerge.KindInt64 {
|
||||||
return false
|
return false
|
||||||
@ -87,9 +79,7 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *CrdtMeshManager) GetPeers() []string {
|
func (c *CrdtMeshManager) GetPeers() []string {
|
||||||
c.lock.RLock()
|
|
||||||
keys, _ := c.doc.Path("nodes").Map().Keys()
|
keys, _ := c.doc.Path("nodes").Map().Keys()
|
||||||
c.lock.RUnlock()
|
|
||||||
|
|
||||||
keys = lib.Filter(keys, func(publicKey string) bool {
|
keys = lib.Filter(keys, func(publicKey string) bool {
|
||||||
return c.isPeer(publicKey) && c.isAlive(publicKey)
|
return c.isPeer(publicKey) && c.isAlive(publicKey)
|
||||||
@ -108,9 +98,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
|||||||
|
|
||||||
if c.cache == nil || len(changes) > 0 {
|
if c.cache == nil || len(changes) > 0 {
|
||||||
c.lastCacheHash = c.LastHash
|
c.lastCacheHash = c.LastHash
|
||||||
c.lock.RLock()
|
|
||||||
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
|
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
|
||||||
c.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -170,7 +158,6 @@ func (m *CrdtMeshManager) NodeExists(key string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
|
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
|
||||||
m.lock.RLock()
|
|
||||||
node, err := m.doc.Path("nodes").Map().Get(endpoint)
|
node, err := m.doc.Path("nodes").Map().Get(endpoint)
|
||||||
|
|
||||||
if node.Kind() != automerge.KindMap {
|
if node.Kind() != automerge.KindMap {
|
||||||
@ -182,7 +169,6 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
meshNode, err := automerge.As[*MeshNodeCrdt](node)
|
meshNode, err := automerge.As[*MeshNodeCrdt](node)
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -228,9 +214,7 @@ func (m *CrdtMeshManager) SaveChanges() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
||||||
m.lock.RLock()
|
|
||||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -240,9 +224,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
|||||||
return errors.New("node is not a map")
|
return errors.New("node is not a map")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
err = node.Map().Set("timestamp", time.Now().Unix())
|
err = node.Map().Set("timestamp", time.Now().Unix())
|
||||||
m.lock.Unlock()
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Log.WriteInfof("Timestamp Updated for %s", nodeId)
|
logging.Log.WriteInfof("Timestamp Updated for %s", nodeId)
|
||||||
@ -252,9 +234,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
|
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
|
||||||
m.lock.RLock()
|
|
||||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -264,9 +244,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
|
|||||||
return fmt.Errorf("%s does not exist", nodeId)
|
return fmt.Errorf("%s does not exist", nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
err = node.Map().Set("description", description)
|
err = node.Map().Set("description", description)
|
||||||
m.lock.Unlock()
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Log.WriteInfof("Description Updated for %s", nodeId)
|
logging.Log.WriteInfof("Description Updated for %s", nodeId)
|
||||||
@ -276,9 +254,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
||||||
m.lock.RLock()
|
|
||||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -288,9 +264,7 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
|||||||
return fmt.Errorf("%s does not exist", nodeId)
|
return fmt.Errorf("%s does not exist", nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
err = node.Map().Set("alias", alias)
|
err = node.Map().Set("alias", alias)
|
||||||
m.lock.Unlock()
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
|
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
|
||||||
@ -300,17 +274,13 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
|
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
|
||||||
m.lock.RLock()
|
|
||||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil || node.Kind() != automerge.KindMap {
|
if err != nil || node.Kind() != automerge.KindMap {
|
||||||
return fmt.Errorf("AddService: node %s does not exist", nodeId)
|
return fmt.Errorf("AddService: node %s does not exist", nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.RLock()
|
|
||||||
service, err := node.Map().Get("services")
|
service, err := node.Map().Get("services")
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -320,14 +290,11 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
|
|||||||
return fmt.Errorf("AddService: services property does not exist in node")
|
return fmt.Errorf("AddService: services property does not exist in node")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
err = service.Map().Set(key, value)
|
err = service.Map().Set(key, value)
|
||||||
m.lock.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
||||||
m.lock.RLock()
|
|
||||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil || node.Kind() != automerge.KindMap {
|
if err != nil || node.Kind() != automerge.KindMap {
|
||||||
@ -343,11 +310,8 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
|||||||
if service.Kind() != automerge.KindMap {
|
if service.Kind() != automerge.KindMap {
|
||||||
return fmt.Errorf("services property does not exist")
|
return fmt.Errorf("services property does not exist")
|
||||||
}
|
}
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
err = service.Map().Delete(key)
|
err = service.Map().Delete(key)
|
||||||
m.lock.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("service %s does not exist", key)
|
return fmt.Errorf("service %s does not exist", key)
|
||||||
@ -358,7 +322,6 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
|||||||
|
|
||||||
// AddRoutes: adds routes to the specific nodeId
|
// AddRoutes: adds routes to the specific nodeId
|
||||||
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
m.lock.RLock()
|
|
||||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
logging.Log.WriteInfof("Adding route to %s", nodeId)
|
logging.Log.WriteInfof("Adding route to %s", nodeId)
|
||||||
|
|
||||||
@ -371,7 +334,6 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
routeMap, err := nodeVal.Map().Get("routes")
|
routeMap, err := nodeVal.Map().Get("routes")
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -400,12 +362,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
|||||||
slices.Equal(route.GetPath(), pathStr)
|
slices.Equal(route.GetPath(), pathStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
err = routeMap.Map().Set(route.GetDestination().String(), Route{
|
err = routeMap.Map().Set(route.GetDestination().String(), Route{
|
||||||
Destination: route.GetDestination().String(),
|
Destination: route.GetDestination().String(),
|
||||||
Path: route.GetPath(),
|
Path: route.GetPath(),
|
||||||
})
|
})
|
||||||
m.lock.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -415,7 +375,6 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
|
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
|
||||||
m.lock.RLock()
|
|
||||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -437,7 +396,6 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
routes, err := automerge.As[map[string]Route](routeMap)
|
routes, err := automerge.As[map[string]Route](routeMap)
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
return lib.MapValues(routes), err
|
return lib.MapValues(routes), err
|
||||||
}
|
}
|
||||||
@ -486,15 +444,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
||||||
m.lock.Lock()
|
|
||||||
err := m.doc.Path("nodes").Map().Delete(nodeId)
|
err := m.doc.Path("nodes").Map().Delete(nodeId)
|
||||||
m.lock.Unlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes deletes the specified routes
|
// DeleteRoutes deletes the specified routes
|
||||||
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
||||||
m.lock.RLock()
|
|
||||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -506,17 +461,14 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
routeMap, err := nodeVal.Map().Get("routes")
|
routeMap, err := nodeVal.Map().Get("routes")
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
err = routeMap.Map().Delete(route)
|
err = routeMap.Map().Delete(route)
|
||||||
}
|
}
|
||||||
m.lock.Unlock()
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -526,7 +478,6 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
||||||
m.lock.RLock()
|
|
||||||
nodes, err := m.doc.Path("nodes").Get()
|
nodes, err := m.doc.Path("nodes").Get()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -538,7 +489,6 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
values, err := nodes.Map().Values()
|
values, err := nodes.Map().Values()
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -553,9 +503,7 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
|||||||
|
|
||||||
nodeMap := node.Map()
|
nodeMap := node.Map()
|
||||||
|
|
||||||
m.lock.RLock()
|
|
||||||
timeStamp, err := nodeMap.Get("timestamp")
|
timeStamp, err := nodeMap.Get("timestamp")
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -601,7 +549,6 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
|
|||||||
_, ipnet, err := net.ParseCIDR(m.WgHost)
|
_, ipnet, err := net.ParseCIDR(m.WgHost)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -629,7 +576,6 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
|
|||||||
ipv6 := m.WgHost[:len(m.WgHost)-4]
|
ipv6 := m.WgHost[:len(m.WgHost)-4]
|
||||||
|
|
||||||
constituents := strings.Split(ipv6, ":")
|
constituents := strings.Split(ipv6, ":")
|
||||||
logging.Log.WriteInfof(ipv6)
|
|
||||||
constituents = constituents[4:]
|
constituents = constituents[4:]
|
||||||
return strings.Join(constituents, ":")
|
return strings.Join(constituents, ":")
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package crdt
|
package automerge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/automerge/automerge-go"
|
"github.com/automerge/automerge-go"
|
||||||
@ -32,6 +32,7 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
|
|||||||
|
|
||||||
func (a *AutomergeSync) Complete() {
|
func (a *AutomergeSync) Complete() {
|
||||||
logging.Log.WriteInfof("Sync Completed")
|
logging.Log.WriteInfof("Sync Completed")
|
||||||
|
a.manager.SaveChanges()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
|
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package crdt
|
package automerge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
"slices"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package crdt
|
package automerge
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package crdt
|
package automerge
|
||||||
|
|
||||||
// Route: Represents a CRDT of the given route
|
// Route: Represents a CRDT of the given route
|
||||||
type Route struct {
|
type Route struct {
|
||||||
|
442
pkg/crdt/datastore.go
Normal file
442
pkg/crdt/datastore.go
Normal file
@ -0,0 +1,442 @@
|
|||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Route struct {
|
||||||
|
Destination string
|
||||||
|
Path []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDestination implements mesh.Route.
|
||||||
|
func (r *Route) GetDestination() *net.IPNet {
|
||||||
|
_, ipnet, _ := net.ParseCIDR(r.Destination)
|
||||||
|
return ipnet
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHopCount implements mesh.Route.
|
||||||
|
func (r *Route) GetHopCount() int {
|
||||||
|
return len(r.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPath implements mesh.Route.
|
||||||
|
func (r *Route) GetPath() []string {
|
||||||
|
return r.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
type MeshNode struct {
|
||||||
|
HostEndpoint string
|
||||||
|
WgEndpoint string
|
||||||
|
PublicKey string
|
||||||
|
WgHost string
|
||||||
|
Timestamp int64
|
||||||
|
Routes map[string]Route
|
||||||
|
Alias string
|
||||||
|
Description string
|
||||||
|
Services map[string]string
|
||||||
|
Type string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHostEndpoint: gets the gRPC endpoint of the node
|
||||||
|
func (n *MeshNode) GetHostEndpoint() string {
|
||||||
|
return n.HostEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPublicKey: gets the public key of the node
|
||||||
|
func (n *MeshNode) GetPublicKey() (wgtypes.Key, error) {
|
||||||
|
return wgtypes.ParseKey(n.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWgEndpoint(): get IP and port of the wireguard endpoint
|
||||||
|
func (n *MeshNode) GetWgEndpoint() string {
|
||||||
|
return n.WgEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWgHost: get the IP address of the WireGuard node
|
||||||
|
func (n *MeshNode) GetWgHost() *net.IPNet {
|
||||||
|
_, ipnet, _ := net.ParseCIDR(n.WgHost)
|
||||||
|
return ipnet
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimestamp: get the UNIX time stamp of the ndoe
|
||||||
|
func (n *MeshNode) GetTimeStamp() int64 {
|
||||||
|
return n.Timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutes: returns the routes that the nodes provides
|
||||||
|
func (n *MeshNode) GetRoutes() []mesh.Route {
|
||||||
|
routes := make([]mesh.Route, len(n.Routes))
|
||||||
|
|
||||||
|
for index, route := range lib.MapValues(n.Routes) {
|
||||||
|
routes[index] = &Route{
|
||||||
|
Destination: route.Destination,
|
||||||
|
Path: route.Path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentifier: returns the identifier of the node
|
||||||
|
func (m *MeshNode) GetIdentifier() string {
|
||||||
|
ipv6 := m.WgHost[:len(m.WgHost)-4]
|
||||||
|
|
||||||
|
constituents := strings.Split(ipv6, ":")
|
||||||
|
constituents = constituents[4:]
|
||||||
|
return strings.Join(constituents, ":")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDescription: returns the description for this node
|
||||||
|
func (n *MeshNode) GetDescription() string {
|
||||||
|
return n.Description
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAlias: associates the node with an alias. Potentially used
|
||||||
|
// for DNS and so forth.
|
||||||
|
func (n *MeshNode) GetAlias() string {
|
||||||
|
return n.Alias
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServices: returns a list of services offered by the node
|
||||||
|
func (n *MeshNode) GetServices() map[string]string {
|
||||||
|
return n.Services
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *MeshNode) GetType() conf.NodeType {
|
||||||
|
return conf.NodeType(n.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MeshSnapshot struct {
|
||||||
|
Nodes map[string]MeshNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodes() returns the nodes in the mesh
|
||||||
|
func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
|
||||||
|
newMap := make(map[string]mesh.MeshNode)
|
||||||
|
|
||||||
|
for key, value := range m.Nodes {
|
||||||
|
newMap[key] = &MeshNode{
|
||||||
|
HostEndpoint: value.HostEndpoint,
|
||||||
|
PublicKey: value.PublicKey,
|
||||||
|
WgHost: value.WgHost,
|
||||||
|
WgEndpoint: value.WgEndpoint,
|
||||||
|
Timestamp: value.Timestamp,
|
||||||
|
Routes: value.Routes,
|
||||||
|
Alias: value.Alias,
|
||||||
|
Description: value.Description,
|
||||||
|
Services: value.Services,
|
||||||
|
Type: value.Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type TwoPhaseStoreMeshManager struct {
|
||||||
|
MeshId string
|
||||||
|
IfName string
|
||||||
|
Client *wgctrl.Client
|
||||||
|
LastClock uint64
|
||||||
|
conf *conf.WgMeshConfiguration
|
||||||
|
store *TwoPhaseMap[string, MeshNode]
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNode() adds a node to the mesh
|
||||||
|
func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
|
||||||
|
crdt, ok := node.(*MeshNode)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
panic("node must be of type mesh node")
|
||||||
|
}
|
||||||
|
|
||||||
|
crdt.Routes = make(map[string]Route)
|
||||||
|
crdt.Services = make(map[string]string)
|
||||||
|
crdt.Timestamp = time.Now().Unix()
|
||||||
|
|
||||||
|
m.store.Put(crdt.PublicKey, *crdt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
||||||
|
return &MeshSnapshot{
|
||||||
|
Nodes: m.store.AsMap(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMeshId() returns the ID of the mesh network
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetMeshId() string {
|
||||||
|
return m.MeshId
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save() saves the mesh network
|
||||||
|
func (m *TwoPhaseStoreMeshManager) Save() []byte {
|
||||||
|
snapshot := m.store.Snapshot()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buf)
|
||||||
|
|
||||||
|
err := enc.Encode(*snapshot)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load() loads a mesh network
|
||||||
|
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
|
||||||
|
buf := bytes.NewBuffer(bs)
|
||||||
|
|
||||||
|
dec := gob.NewDecoder(buf)
|
||||||
|
|
||||||
|
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
|
||||||
|
err := dec.Decode(&snapshot)
|
||||||
|
m.store.Merge(snapshot)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevice() get the device corresponding with the mesh
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) {
|
||||||
|
dev, err := m.Client.Device(m.IfName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dev, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasChanges returns true if we have changes since last time we synced
|
||||||
|
func (m *TwoPhaseStoreMeshManager) HasChanges() bool {
|
||||||
|
clockValue := m.store.GetClock()
|
||||||
|
return clockValue != m.LastClock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record that we have changes and save the corresponding changes
|
||||||
|
func (m *TwoPhaseStoreMeshManager) SaveChanges() {
|
||||||
|
clockValue := m.store.GetClock()
|
||||||
|
m.LastClock = clockValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTimeStamp: update the timestamp of the given node
|
||||||
|
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
node.Timestamp = time.Now().Unix()
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRoutes: adds routes to the given node
|
||||||
|
func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(routes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
node.Routes[route.GetDestination().String()] = Route{
|
||||||
|
Destination: route.GetDestination().String(),
|
||||||
|
Path: route.GetPath(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoutes: deletes the routes from the node
|
||||||
|
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(routes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
delete(node.Routes, route)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSyncer: returns the automerge syncer for sync
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
|
||||||
|
return NewTwoPhaseSyncer(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNode get a particular not within the mesh
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
return &node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeExists: returns true if a particular node exists false otherwise
|
||||||
|
func (m *TwoPhaseStoreMeshManager) NodeExists(nodeId string) bool {
|
||||||
|
return m.store.Contains(nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDescription: sets the description of this automerge data type
|
||||||
|
func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
node.Description = description
|
||||||
|
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAlias: set the alias of the nodeId
|
||||||
|
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
node.Description = alias
|
||||||
|
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddService: adds the service to the given node
|
||||||
|
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
node.Services[key] = value
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveService: removes the service form the node. throws an error if the service does not exist
|
||||||
|
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(nodeId)
|
||||||
|
delete(node.Services, key)
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prune: prunes all nodes that have not updated their timestamp in
|
||||||
|
// pruneAmount seconds
|
||||||
|
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeers: get a list of contactable peers
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
||||||
|
nodes := lib.MapValues(m.store.AsMap())
|
||||||
|
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
||||||
|
if mn.Type != string(conf.PEER_ROLE) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
|
||||||
|
})
|
||||||
|
|
||||||
|
return lib.Map(nodes, func(mn MeshNode) string {
|
||||||
|
return mn.PublicKey
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
|
||||||
|
if !m.store.Contains(targetNode) {
|
||||||
|
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
|
||||||
|
}
|
||||||
|
|
||||||
|
node := m.store.Get(targetNode)
|
||||||
|
return node.Routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
|
||||||
|
node, err := m.GetNode(targetNode)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make(map[string]mesh.Route)
|
||||||
|
|
||||||
|
// Add routes that the node directly has
|
||||||
|
for _, route := range node.GetRoutes() {
|
||||||
|
routes[route.GetDestination().String()] = route
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work out the other routes in the mesh
|
||||||
|
for _, node := range m.GetPeers() {
|
||||||
|
nodeRoutes, err := m.getRoutes(node)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range nodeRoutes {
|
||||||
|
otherRoute, ok := routes[route.GetDestination().String()]
|
||||||
|
|
||||||
|
hopCount := route.GetHopCount()
|
||||||
|
|
||||||
|
if node != targetNode {
|
||||||
|
hopCount += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
|
||||||
|
routes[route.GetDestination().String()] = &Route{
|
||||||
|
Destination: route.GetDestination().String(),
|
||||||
|
Path: append(route.GetPath(), m.GetMeshId()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveNode(): remove the node from the mesh
|
||||||
|
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
|
||||||
|
if !m.store.Contains(nodeId) {
|
||||||
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.store.Remove(nodeId)
|
||||||
|
return nil
|
||||||
|
}
|
73
pkg/crdt/factory.go
Normal file
73
pkg/crdt/factory.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TwoPhaseMapFactory struct{}
|
||||||
|
|
||||||
|
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
|
||||||
|
return &TwoPhaseStoreMeshManager{
|
||||||
|
MeshId: params.MeshId,
|
||||||
|
IfName: params.DevName,
|
||||||
|
Client: params.Client,
|
||||||
|
conf: params.Conf,
|
||||||
|
store: NewTwoPhaseMap[string, MeshNode](params.NodeID),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MeshNodeFactory struct {
|
||||||
|
Config conf.WgMeshConfiguration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
|
||||||
|
hostName := f.getAddress(params)
|
||||||
|
|
||||||
|
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
|
||||||
|
|
||||||
|
if f.Config.Role == conf.CLIENT_ROLE {
|
||||||
|
grpcEndpoint = "-"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MeshNode{
|
||||||
|
HostEndpoint: grpcEndpoint,
|
||||||
|
PublicKey: params.PublicKey.String(),
|
||||||
|
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
|
||||||
|
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
|
||||||
|
Routes: make(map[string]Route),
|
||||||
|
Description: "",
|
||||||
|
Alias: "",
|
||||||
|
Type: string(f.Config.Role),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAddress returns the routable address of the machine.
|
||||||
|
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
|
||||||
|
var hostName string = ""
|
||||||
|
|
||||||
|
if params.Endpoint != "" {
|
||||||
|
hostName = params.Endpoint
|
||||||
|
} else if len(f.Config.Endpoint) != 0 {
|
||||||
|
hostName = f.Config.Endpoint
|
||||||
|
} else {
|
||||||
|
ipFunc := lib.GetPublicIP
|
||||||
|
|
||||||
|
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
|
||||||
|
ipFunc = lib.GetOutboundIP
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := ipFunc()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
hostName = ip.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostName
|
||||||
|
}
|
121
pkg/crdt/g_map.go
Normal file
121
pkg/crdt/g_map.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
// crdt is a golang implementation of a crdt
|
||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bucket[D any] struct {
|
||||||
|
Vector uint64
|
||||||
|
Contents D
|
||||||
|
}
|
||||||
|
|
||||||
|
// GMap is a set that can only grow in size
|
||||||
|
type GMap[K comparable, D any] struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
contents map[K]Bucket[D]
|
||||||
|
getClock func() uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Put(key K, value D) {
|
||||||
|
g.lock.Lock()
|
||||||
|
|
||||||
|
clock := g.getClock() + 1
|
||||||
|
|
||||||
|
g.contents[key] = Bucket[D]{
|
||||||
|
Vector: clock,
|
||||||
|
Contents: value,
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Contains(key K) bool {
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
_, ok := g.contents[key]
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
||||||
|
g.lock.Lock()
|
||||||
|
|
||||||
|
if g.contents[key].Vector < b.Vector {
|
||||||
|
g.contents[key] = b
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) get(key K) Bucket[D] {
|
||||||
|
g.lock.RLock()
|
||||||
|
bucket := g.contents[key]
|
||||||
|
g.lock.RUnlock()
|
||||||
|
|
||||||
|
return bucket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Get(key K) D {
|
||||||
|
return g.get(key).Contents
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Keys() []K {
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
contents := make([]K, len(g.contents))
|
||||||
|
index := 0
|
||||||
|
|
||||||
|
for key := range g.contents {
|
||||||
|
contents[index] = key
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
return contents
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
||||||
|
buckets := make(map[K]Bucket[D])
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
for key, value := range g.contents {
|
||||||
|
buckets[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
return buckets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
||||||
|
buckets := make(map[K]Bucket[D])
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
buckets[key] = g.contents[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
return buckets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) GetClock() map[K]uint64 {
|
||||||
|
clock := make(map[K]uint64)
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
for key, bucket := range g.contents {
|
||||||
|
clock[key] = bucket.Vector
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
return clock
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] {
|
||||||
|
return &GMap[K, D]{
|
||||||
|
contents: make(map[K]Bucket[D]),
|
||||||
|
getClock: getClock,
|
||||||
|
}
|
||||||
|
}
|
208
pkg/crdt/two_phase_map.go
Normal file
208
pkg/crdt/two_phase_map.go
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TwoPhaseMap[K comparable, D any] struct {
|
||||||
|
addMap *GMap[K, D]
|
||||||
|
removeMap *GMap[K, bool]
|
||||||
|
vectors map[K]uint64
|
||||||
|
processId K
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type TwoPhaseMapSnapshot[K comparable, D any] struct {
|
||||||
|
Add map[K]Bucket[D]
|
||||||
|
Remove map[K]Bucket[bool]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks whether the value exists in the map
|
||||||
|
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
|
||||||
|
if !m.addMap.Contains(key) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
addValue := m.addMap.get(key)
|
||||||
|
|
||||||
|
if !m.removeMap.Contains(key) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
removeValue := m.removeMap.get(key)
|
||||||
|
|
||||||
|
return addValue.Vector >= removeValue.Vector
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) Get(key K) D {
|
||||||
|
var result D
|
||||||
|
|
||||||
|
if !m.Contains(key) {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.addMap.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put places the key K in the map
|
||||||
|
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
||||||
|
msgSequence := m.incrementClock()
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
if _, ok := m.vectors[key]; !ok {
|
||||||
|
m.vectors[key] = msgSequence
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
m.addMap.Put(key, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes the value from the map
|
||||||
|
func (m *TwoPhaseMap[K, D]) Remove(key K) {
|
||||||
|
m.removeMap.Put(key, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) Keys() []K {
|
||||||
|
keys := make([]K, 0)
|
||||||
|
|
||||||
|
addKeys := m.addMap.Keys()
|
||||||
|
|
||||||
|
for _, key := range addKeys {
|
||||||
|
if !m.Contains(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) AsMap() map[K]D {
|
||||||
|
theMap := make(map[K]D)
|
||||||
|
|
||||||
|
keys := m.Keys()
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
theMap[key] = m.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return theMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
|
||||||
|
return &TwoPhaseMapSnapshot[K, D]{
|
||||||
|
Add: m.addMap.Save(),
|
||||||
|
Remove: m.removeMap.Save(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
|
||||||
|
addKeys := lib.MapKeys(state.AddContents)
|
||||||
|
removeKeys := lib.MapKeys(state.RemoveContents)
|
||||||
|
|
||||||
|
return &TwoPhaseMapSnapshot[K, D]{
|
||||||
|
Add: m.addMap.SaveWithKeys(addKeys),
|
||||||
|
Remove: m.removeMap.SaveWithKeys(removeKeys),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TwoPhaseMapState[K comparable] struct {
|
||||||
|
AddContents map[K]uint64
|
||||||
|
RemoveContents map[K]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
|
||||||
|
maxClock := uint64(0)
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
for _, value := range m.vectors {
|
||||||
|
maxClock = max(maxClock, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.vectors[m.processId] = maxClock + 1
|
||||||
|
m.lock.Unlock()
|
||||||
|
return maxClock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) GetClock() uint64 {
|
||||||
|
maxClock := uint64(0)
|
||||||
|
m.lock.RLock()
|
||||||
|
|
||||||
|
for _, value := range m.vectors {
|
||||||
|
maxClock = max(maxClock, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return maxClock
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetState: get the current vector clock of the add and remove
|
||||||
|
// map
|
||||||
|
func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
|
||||||
|
addContents := m.addMap.GetClock()
|
||||||
|
removeContents := m.removeMap.GetClock()
|
||||||
|
|
||||||
|
return &TwoPhaseMapState[K]{
|
||||||
|
AddContents: addContents,
|
||||||
|
RemoveContents: removeContents,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
||||||
|
mapState := &TwoPhaseMapState[K]{
|
||||||
|
AddContents: make(map[K]uint64),
|
||||||
|
RemoveContents: make(map[K]uint64),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range state.AddContents {
|
||||||
|
otherValue, ok := m.AddContents[key]
|
||||||
|
|
||||||
|
if !ok || otherValue < value {
|
||||||
|
mapState.AddContents[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range state.AddContents {
|
||||||
|
otherValue, ok := m.RemoveContents[key]
|
||||||
|
|
||||||
|
if !ok || otherValue < value {
|
||||||
|
mapState.RemoveContents[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
for key, value := range snapshot.Add {
|
||||||
|
m.addMap.put(key, value)
|
||||||
|
m.vectors[key] = max(value.Vector, m.vectors[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range snapshot.Remove {
|
||||||
|
m.removeMap.put(key, value)
|
||||||
|
m.vectors[key] = max(value.Vector, m.vectors[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
|
||||||
|
// a grow map and a remove map. If both timestamps equal then favour keeping
|
||||||
|
// it in the map
|
||||||
|
func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] {
|
||||||
|
m := TwoPhaseMap[K, D]{
|
||||||
|
vectors: make(map[K]uint64),
|
||||||
|
processId: processId,
|
||||||
|
}
|
||||||
|
|
||||||
|
m.addMap = NewGMap[K, D](m.incrementClock)
|
||||||
|
m.removeMap = NewGMap[K, bool](m.incrementClock)
|
||||||
|
return &m
|
||||||
|
}
|
145
pkg/crdt/two_phase_map_syncer.go
Normal file
145
pkg/crdt/two_phase_map_syncer.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/gob"
|
||||||
|
|
||||||
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SyncState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
PREPARE SyncState = iota
|
||||||
|
PRESENT
|
||||||
|
EXCHANGE
|
||||||
|
MERGE
|
||||||
|
FINISHED
|
||||||
|
)
|
||||||
|
|
||||||
|
// TwoPhaseSyncer is a type to sync a TwoPhase data store
|
||||||
|
type TwoPhaseSyncer struct {
|
||||||
|
manager *TwoPhaseStoreMeshManager
|
||||||
|
generateMessageFSM SyncFSM
|
||||||
|
state SyncState
|
||||||
|
mapState *TwoPhaseMapState[string]
|
||||||
|
peerMsg []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
|
||||||
|
|
||||||
|
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buffer)
|
||||||
|
|
||||||
|
err := enc.Encode(*syncer.mapState)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
syncer.IncrementState()
|
||||||
|
return buffer.Bytes(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
if syncer.peerMsg == nil {
|
||||||
|
panic("peer msg is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
|
||||||
|
dec := gob.NewDecoder(recvBuffer)
|
||||||
|
|
||||||
|
var mapState TwoPhaseMapState[string]
|
||||||
|
err := dec.Decode(&mapState)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
difference := syncer.mapState.Difference(&mapState)
|
||||||
|
|
||||||
|
var sendBuffer bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&sendBuffer)
|
||||||
|
enc.Encode(*difference)
|
||||||
|
|
||||||
|
syncer.IncrementState()
|
||||||
|
return sendBuffer.Bytes(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchange(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
if syncer.peerMsg == nil {
|
||||||
|
panic("peer msg is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
|
||||||
|
dec := gob.NewDecoder(recvBuffer)
|
||||||
|
|
||||||
|
var mapState TwoPhaseMapState[string]
|
||||||
|
dec.Decode(&mapState)
|
||||||
|
|
||||||
|
snapshot := syncer.manager.store.SnapShotFromState(&mapState)
|
||||||
|
|
||||||
|
var sendBuffer bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&sendBuffer)
|
||||||
|
enc.Encode(*snapshot)
|
||||||
|
|
||||||
|
syncer.IncrementState()
|
||||||
|
return sendBuffer.Bytes(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
if syncer.peerMsg == nil {
|
||||||
|
panic("peer msg is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
|
||||||
|
dec := gob.NewDecoder(recvBuffer)
|
||||||
|
|
||||||
|
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
|
||||||
|
dec.Decode(&snapshot)
|
||||||
|
|
||||||
|
syncer.manager.store.Merge(snapshot)
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TwoPhaseSyncer) IncrementState() {
|
||||||
|
t.state = min(t.state+1, FINISHED)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TwoPhaseSyncer) GenerateMessage() ([]byte, bool) {
|
||||||
|
fsmFunc, ok := t.generateMessageFSM[t.state]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
panic("state not handled")
|
||||||
|
}
|
||||||
|
|
||||||
|
return fsmFunc(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
|
||||||
|
t.peerMsg = msg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TwoPhaseSyncer) Complete() {
|
||||||
|
logging.Log.WriteInfof("SYNC COMPLETED")
|
||||||
|
t.manager.SaveChanges()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
||||||
|
var generateMessageFsm SyncFSM = SyncFSM{
|
||||||
|
PREPARE: prepare,
|
||||||
|
PRESENT: present,
|
||||||
|
EXCHANGE: exchange,
|
||||||
|
MERGE: merge,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TwoPhaseSyncer{
|
||||||
|
manager: manager,
|
||||||
|
state: PREPARE,
|
||||||
|
mapState: manager.store.GenerateMessage(),
|
||||||
|
generateMessageFSM: generateMessageFsm,
|
||||||
|
}
|
||||||
|
}
|
@ -1,9 +1,9 @@
|
|||||||
package ctrlserver
|
package ctrlserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conn"
|
"github.com/tim-beatham/wgmesh/pkg/conn"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/crdt"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
@ -28,8 +28,8 @@ type NewCtrlServerParams struct {
|
|||||||
// operation failed
|
// operation failed
|
||||||
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||||
ctrlServer := new(MeshCtrlServer)
|
ctrlServer := new(MeshCtrlServer)
|
||||||
meshFactory := crdt.CrdtProviderFactory{}
|
meshFactory := &crdt.TwoPhaseMapFactory{}
|
||||||
nodeFactory := crdt.MeshNodeFactory{
|
nodeFactory := &crdt.MeshNodeFactory{
|
||||||
Config: *params.Conf,
|
Config: *params.Conf,
|
||||||
}
|
}
|
||||||
idGenerator := &lib.IDNameGenerator{}
|
idGenerator := &lib.IDNameGenerator{}
|
||||||
@ -41,8 +41,8 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
|||||||
meshManagerParams := &mesh.NewMeshManagerParams{
|
meshManagerParams := &mesh.NewMeshManagerParams{
|
||||||
Conf: *params.Conf,
|
Conf: *params.Conf,
|
||||||
Client: params.Client,
|
Client: params.Client,
|
||||||
MeshProvider: &meshFactory,
|
MeshProvider: meshFactory,
|
||||||
NodeFactory: &nodeFactory,
|
NodeFactory: nodeFactory,
|
||||||
IdGenerator: idGenerator,
|
IdGenerator: idGenerator,
|
||||||
IPAllocator: ipAllocator,
|
IPAllocator: ipAllocator,
|
||||||
InterfaceManipulator: interfaceManipulator,
|
InterfaceManipulator: interfaceManipulator,
|
||||||
|
@ -226,8 +226,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := wgtypes.Config{
|
cfg := wgtypes.Config{
|
||||||
Peers: peerConfigs,
|
Peers: peerConfigs,
|
||||||
ReplacePeers: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err := mesh.GetDevice()
|
dev, err := mesh.GetDevice()
|
||||||
|
@ -146,6 +146,7 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
|||||||
Conf: m.conf,
|
Conf: m.conf,
|
||||||
Client: m.Client,
|
Client: m.Client,
|
||||||
MeshId: meshId,
|
MeshId: meshId,
|
||||||
|
NodeID: m.HostParameters.GetPublicKey(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -183,6 +184,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
|||||||
Conf: m.conf,
|
Conf: m.conf,
|
||||||
Client: m.Client,
|
Client: m.Client,
|
||||||
MeshId: params.MeshId,
|
MeshId: params.MeshId,
|
||||||
|
NodeID: m.HostParameters.GetPublicKey(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -214,11 +216,6 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
|
|||||||
|
|
||||||
// GetPublicKey: Gets the public key of the WireGuard mesh
|
// GetPublicKey: Gets the public key of the WireGuard mesh
|
||||||
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
|
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
|
||||||
if s.conf.StubWg {
|
|
||||||
zeroedKey := make([]byte, wgtypes.KeyLen)
|
|
||||||
return (*wgtypes.Key)(zeroedKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
key := s.HostParameters.PrivateKey.PublicKey()
|
key := s.HostParameters.PrivateKey.PublicKey()
|
||||||
return &key
|
return &key
|
||||||
}
|
}
|
||||||
|
@ -159,6 +159,7 @@ type MeshProviderFactoryParams struct {
|
|||||||
Port int
|
Port int
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.WgMeshConfiguration
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
|
NodeID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// MeshProviderFactory creates an instance of a mesh provider
|
// MeshProviderFactory creates an instance of a mesh provider
|
||||||
|
@ -45,6 +45,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
|
|
||||||
publicKey := s.manager.GetPublicKey()
|
publicKey := s.manager.GetPublicKey()
|
||||||
|
|
||||||
|
logging.Log.WriteInfof(publicKey.String())
|
||||||
|
|
||||||
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
||||||
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
||||||
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
||||||
@ -87,11 +89,6 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
|
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
|
||||||
|
|
||||||
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
|
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
|
||||||
|
|
||||||
// Check if any changes have occurred and trigger callbacks
|
|
||||||
// if changes have occurred.
|
|
||||||
// return s.manager.GetMonitor().Trigger()
|
|
||||||
s.manager.GetMesh(meshId).SaveChanges()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,10 +3,12 @@ package timer
|
|||||||
import (
|
import (
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
|
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
|
||||||
timerFunc := func() error {
|
timerFunc := func() error {
|
||||||
|
logging.Log.WriteInfof("Updated Timestamp")
|
||||||
return ctrlServer.MeshManager.UpdateTimeStamp()
|
return ctrlServer.MeshManager.UpdateTimeStamp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user