forked from extern/smegmesh
650901aba1
Implemented my own two phase map based on vector clocks
636 lines
13 KiB
Go
636 lines
13 KiB
Go
package automerge
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/automerge/automerge-go"
|
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
|
"golang.zx2c4.com/wireguard/wgctrl"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
// CrdtMeshManager manages nodes in the crdt mesh
|
|
type CrdtMeshManager struct {
|
|
MeshId string
|
|
IfName string
|
|
Client *wgctrl.Client
|
|
doc *automerge.Doc
|
|
LastHash automerge.ChangeHash
|
|
conf *conf.WgMeshConfiguration
|
|
cache *MeshCrdt
|
|
lastCacheHash automerge.ChangeHash
|
|
}
|
|
|
|
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
|
|
crdt, ok := node.(*MeshNodeCrdt)
|
|
|
|
if !ok {
|
|
panic("node must be of type *MeshNodeCrdt")
|
|
}
|
|
|
|
crdt.Routes = make(map[string]Route)
|
|
crdt.Services = make(map[string]string)
|
|
crdt.Timestamp = time.Now().Unix()
|
|
|
|
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
|
|
}
|
|
|
|
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
|
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil || node.Kind() != automerge.KindMap {
|
|
return false
|
|
}
|
|
|
|
nodeType, err := node.Map().Get("type")
|
|
|
|
if err != nil || nodeType.Kind() != automerge.KindStr {
|
|
return false
|
|
}
|
|
|
|
return nodeType.Str() == string(conf.PEER_ROLE)
|
|
}
|
|
|
|
// isAlive: checks that the node's configuration has been updated
|
|
// since the rquired keep alive time
|
|
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
|
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil || node.Kind() != automerge.KindMap {
|
|
return false
|
|
}
|
|
|
|
timestamp, err := node.Map().Get("timestamp")
|
|
|
|
if err != nil || timestamp.Kind() != automerge.KindInt64 {
|
|
return false
|
|
}
|
|
|
|
keepAliveTime := timestamp.Int64()
|
|
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
|
|
}
|
|
|
|
func (c *CrdtMeshManager) GetPeers() []string {
|
|
keys, _ := c.doc.Path("nodes").Map().Keys()
|
|
|
|
keys = lib.Filter(keys, func(publicKey string) bool {
|
|
return c.isPeer(publicKey) && c.isAlive(publicKey)
|
|
})
|
|
|
|
return keys
|
|
}
|
|
|
|
// GetMesh(): Converts the document into a struct
|
|
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
|
changes, err := c.doc.Changes(c.lastCacheHash)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if c.cache == nil || len(changes) > 0 {
|
|
c.lastCacheHash = c.LastHash
|
|
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.cache = cache
|
|
}
|
|
|
|
return c.cache, nil
|
|
}
|
|
|
|
// GetMeshId returns the meshid of the mesh
|
|
func (c *CrdtMeshManager) GetMeshId() string {
|
|
return c.MeshId
|
|
}
|
|
|
|
// Save: Save an entire mesh network
|
|
func (c *CrdtMeshManager) Save() []byte {
|
|
return c.doc.Save()
|
|
}
|
|
|
|
// Load: Load an entire mesh network
|
|
func (c *CrdtMeshManager) Load(bytes []byte) error {
|
|
doc, err := automerge.Load(bytes)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.doc = doc
|
|
return nil
|
|
}
|
|
|
|
type NewCrdtNodeMangerParams struct {
|
|
MeshId string
|
|
DevName string
|
|
Port int
|
|
Conf conf.WgMeshConfiguration
|
|
Client *wgctrl.Client
|
|
}
|
|
|
|
// NewCrdtNodeManager: Create a new crdt node manager
|
|
func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) {
|
|
var manager CrdtMeshManager
|
|
manager.MeshId = params.MeshId
|
|
manager.doc = automerge.New()
|
|
manager.IfName = params.DevName
|
|
manager.Client = params.Client
|
|
manager.conf = ¶ms.Conf
|
|
manager.cache = nil
|
|
return &manager, nil
|
|
}
|
|
|
|
// NodeExists: returns true if the node exists. Returns false
|
|
func (m *CrdtMeshManager) NodeExists(key string) bool {
|
|
node, err := m.doc.Path("nodes").Map().Get(key)
|
|
return node.Kind() == automerge.KindMap && err == nil
|
|
}
|
|
|
|
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
|
|
node, err := m.doc.Path("nodes").Map().Get(endpoint)
|
|
|
|
if node.Kind() != automerge.KindMap {
|
|
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
meshNode, err := automerge.As[*MeshNodeCrdt](node)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return meshNode, nil
|
|
}
|
|
|
|
func (m *CrdtMeshManager) Length() int {
|
|
return m.doc.Path("nodes").Map().Len()
|
|
}
|
|
|
|
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
|
|
dev, err := m.Client.Device(m.IfName)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return dev, nil
|
|
}
|
|
|
|
// HasChanges returns true if we have changes since the last time we synced
|
|
func (m *CrdtMeshManager) HasChanges() bool {
|
|
changes, err := m.doc.Changes(m.LastHash)
|
|
|
|
logging.Log.WriteInfof("Changes %s", m.LastHash.String())
|
|
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
logging.Log.WriteInfof("Changes length %d", len(changes))
|
|
return len(changes) > 0
|
|
}
|
|
|
|
func (m *CrdtMeshManager) SaveChanges() {
|
|
hashes := m.doc.Heads()
|
|
hash := hashes[len(hashes)-1]
|
|
|
|
logging.Log.WriteInfof("Saved Hash %s", hash.String())
|
|
m.LastHash = hash
|
|
}
|
|
|
|
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if node.Kind() != automerge.KindMap {
|
|
return errors.New("node is not a map")
|
|
}
|
|
|
|
err = node.Map().Set("timestamp", time.Now().Unix())
|
|
|
|
if err == nil {
|
|
logging.Log.WriteInfof("Timestamp Updated for %s", nodeId)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
|
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if node.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("%s does not exist", nodeId)
|
|
}
|
|
|
|
err = node.Map().Set("description", description)
|
|
|
|
if err == nil {
|
|
logging.Log.WriteInfof("Description Updated for %s", nodeId)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if node.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("%s does not exist", nodeId)
|
|
}
|
|
|
|
err = node.Map().Set("alias", alias)
|
|
|
|
if err == nil {
|
|
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
|
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil || node.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("AddService: node %s does not exist", nodeId)
|
|
}
|
|
|
|
service, err := node.Map().Get("services")
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if service.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("AddService: services property does not exist in node")
|
|
}
|
|
|
|
err = service.Map().Set(key, value)
|
|
return err
|
|
}
|
|
|
|
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
|
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil || node.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
|
|
}
|
|
|
|
service, err := node.Map().Get("services")
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if service.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("services property does not exist")
|
|
}
|
|
|
|
err = service.Map().Delete(key)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("service %s does not exist", key)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddRoutes: adds routes to the specific nodeId
|
|
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
logging.Log.WriteInfof("Adding route to %s", nodeId)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if nodeVal.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("node does not exist")
|
|
}
|
|
|
|
routeMap, err := nodeVal.Map().Get("routes")
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, route := range routes {
|
|
prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
|
|
|
|
if prevRoute.Kind() == automerge.KindVoid && err != nil {
|
|
path, err := prevRoute.Map().Get("path")
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if path.Kind() != automerge.KindList {
|
|
return fmt.Errorf("path is not a list")
|
|
}
|
|
|
|
pathStr, err := automerge.As[[]string](path)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
slices.Equal(route.GetPath(), pathStr)
|
|
}
|
|
|
|
err = routeMap.Map().Set(route.GetDestination().String(), Route{
|
|
Destination: route.GetDestination().String(),
|
|
Path: route.GetPath(),
|
|
})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
|
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if nodeVal.Kind() != automerge.KindMap {
|
|
return nil, fmt.Errorf("node does not exist")
|
|
}
|
|
|
|
routeMap, err := nodeVal.Map().Get("routes")
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if routeMap.Kind() != automerge.KindMap {
|
|
return nil, fmt.Errorf("node %s is not a map", nodeId)
|
|
}
|
|
|
|
routes, err := automerge.As[map[string]Route](routeMap)
|
|
|
|
return lib.MapValues(routes), err
|
|
}
|
|
|
|
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
|
|
node, err := m.GetNode(targetNode)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
routes := make(map[string]mesh.Route)
|
|
|
|
// Add routes that the node directly has
|
|
for _, route := range node.GetRoutes() {
|
|
routes[route.GetDestination().String()] = route
|
|
}
|
|
|
|
// Work out the other routes in the mesh
|
|
for _, node := range m.GetPeers() {
|
|
nodeRoutes, err := m.getRoutes(node)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, route := range nodeRoutes {
|
|
otherRoute, ok := routes[route.GetDestination().String()]
|
|
|
|
hopCount := route.GetHopCount()
|
|
|
|
if node != targetNode {
|
|
hopCount += 1
|
|
}
|
|
|
|
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
|
|
routes[route.GetDestination().String()] = &Route{
|
|
Destination: route.GetDestination().String(),
|
|
Path: append(route.Path, m.GetMeshId()),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return routes, nil
|
|
}
|
|
|
|
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
|
err := m.doc.Path("nodes").Map().Delete(nodeId)
|
|
return err
|
|
}
|
|
|
|
// DeleteRoutes deletes the specified routes
|
|
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if nodeVal.Kind() != automerge.KindMap {
|
|
return fmt.Errorf("node is not a map")
|
|
}
|
|
|
|
routeMap, err := nodeVal.Map().Get("routes")
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, route := range routes {
|
|
err = routeMap.Map().Delete(route)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
|
return NewAutomergeSync(m)
|
|
}
|
|
|
|
func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
|
nodes, err := m.doc.Path("nodes").Get()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if nodes.Kind() != automerge.KindMap {
|
|
return errors.New("node must be a map")
|
|
}
|
|
|
|
values, err := nodes.Map().Values()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
deletionNodes := make([]string, 0)
|
|
|
|
for nodeId, node := range values {
|
|
if node.Kind() != automerge.KindMap {
|
|
return errors.New("node must be a map")
|
|
}
|
|
|
|
nodeMap := node.Map()
|
|
|
|
timeStamp, err := nodeMap.Get("timestamp")
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if timeStamp.Kind() != automerge.KindInt64 {
|
|
return errors.New("timestamp is not int64")
|
|
}
|
|
|
|
timeValue := timeStamp.Int64()
|
|
nowValue := time.Now().Unix()
|
|
|
|
if nowValue-timeValue >= int64(pruneTime) {
|
|
deletionNodes = append(deletionNodes, nodeId)
|
|
}
|
|
}
|
|
|
|
for _, node := range deletionNodes {
|
|
logging.Log.WriteInfof("Pruning %s", node)
|
|
nodes.Map().Delete(node)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
|
|
return strings.Compare(m1.PublicKey, m2.PublicKey)
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetHostEndpoint() string {
|
|
return m.HostEndpoint
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
|
|
return wgtypes.ParseKey(m.PublicKey)
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetWgEndpoint() string {
|
|
return m.WgEndpoint
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
|
|
_, ipnet, err := net.ParseCIDR(m.WgHost)
|
|
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return ipnet
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetTimeStamp() int64 {
|
|
return m.Timestamp
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
|
|
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
|
|
return &Route{
|
|
Destination: r.Destination,
|
|
Path: r.Path,
|
|
}
|
|
})
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetDescription() string {
|
|
return m.Description
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetIdentifier() string {
|
|
ipv6 := m.WgHost[:len(m.WgHost)-4]
|
|
|
|
constituents := strings.Split(ipv6, ":")
|
|
constituents = constituents[4:]
|
|
return strings.Join(constituents, ":")
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetAlias() string {
|
|
return m.Alias
|
|
}
|
|
|
|
func (m *MeshNodeCrdt) GetServices() map[string]string {
|
|
services := make(map[string]string)
|
|
|
|
for key, service := range m.Services {
|
|
services[key] = service
|
|
}
|
|
|
|
return services
|
|
}
|
|
|
|
// GetType refers to the type of the node. Peer means that the node is globally accessible
|
|
// Client means the node is only accessible through another peer
|
|
func (n *MeshNodeCrdt) GetType() conf.NodeType {
|
|
return conf.NodeType(n.Type)
|
|
}
|
|
|
|
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
|
|
nodes := make(map[string]mesh.MeshNode)
|
|
|
|
for _, node := range m.Nodes {
|
|
nodes[node.HostEndpoint] = &MeshNodeCrdt{
|
|
HostEndpoint: node.HostEndpoint,
|
|
WgEndpoint: node.WgEndpoint,
|
|
PublicKey: node.PublicKey,
|
|
WgHost: node.WgHost,
|
|
Timestamp: node.Timestamp,
|
|
Routes: node.Routes,
|
|
Description: node.Description,
|
|
Alias: node.Alias,
|
|
Services: node.GetServices(),
|
|
Type: node.Type,
|
|
}
|
|
}
|
|
|
|
return nodes
|
|
}
|
|
|
|
func (r *Route) GetDestination() *net.IPNet {
|
|
_, ipnet, _ := net.ParseCIDR(r.Destination)
|
|
return ipnet
|
|
}
|
|
|
|
func (r *Route) GetHopCount() int {
|
|
return len(r.Path)
|
|
}
|
|
|
|
func (r *Route) GetPath() []string {
|
|
return r.Path
|
|
}
|