smegmesh/pkg/automerge/automerge.go
2023-10-30 19:02:28 +00:00

260 lines
5.5 KiB
Go

package crdt
import (
"errors"
"net"
"strings"
"time"
"github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
}
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNodeCrdt)
if !ok {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
nodeVal.Map().Set("routes", automerge.NewMap())
}
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
}
// GetMeshId returns the meshid of the mesh
func (c *CrdtMeshManager) GetMeshId() string {
return c.MeshId
}
// Save: Save an entire mesh network
func (c *CrdtMeshManager) Save() []byte {
return c.doc.Save()
}
// Load: Load an entire mesh network
func (c *CrdtMeshManager) Load(bytes []byte) error {
doc, err := automerge.Load(bytes)
if err != nil {
return err
}
c.doc = doc
return nil
}
type NewCrdtNodeMangerParams struct {
MeshId string
DevName string
Port int
Conf conf.WgMeshConfiguration
Client *wgctrl.Client
}
// NewCrdtNodeManager: Create a new crdt node manager
func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) {
var manager CrdtMeshManager
manager.MeshId = params.MeshId
manager.doc = automerge.New()
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
return &manager, nil
}
func (c *CrdtMeshManager) removeNode(endpoint string) error {
err := c.doc.Path("nodes").Map().Delete(endpoint)
if err != nil {
return err
}
return nil
}
// GetNode: returns a mesh node crdt.Close releases resources used by a Client.
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if err != nil {
return nil, err
}
meshNode, err := automerge.As[*MeshNodeCrdt](node)
if err != nil {
return nil, err
}
return meshNode, nil
}
func (m *CrdtMeshManager) Length() int {
return m.doc.Path("nodes").Map().Len()
}
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
return nil, err
}
return dev, nil
}
// HasChanges returns true if we have changes since the last time we synced
func (m *CrdtMeshManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash)
logging.Log.WriteInfof("Changes %s", m.LastHash.String())
if err != nil {
return false
}
logging.Log.WriteInfof("Changes length %d", len(changes))
return len(changes) > 0
}
func (m *CrdtMeshManager) SaveChanges() {
hashes := m.doc.Heads()
hash := hashes[len(hashes)-1]
logging.Log.WriteInfof("Saved Hash %s", hash.String())
m.LastHash = hash
}
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return errors.New("node is not a map")
}
err = node.Map().Set("timestamp", time.Now().Unix())
if err == nil {
logging.Log.WriteInfof("Timestamp Updated for %s", m.MeshId)
}
return err
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId)
if err != nil {
return err
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
if err != nil {
return err
}
}
return nil
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
func (m *MeshNodeCrdt) GetHostEndpoint() string {
return m.HostEndpoint
}
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(m.PublicKey)
}
func (m *MeshNodeCrdt) GetWgEndpoint() string {
return m.WgEndpoint
}
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost)
if err != nil {
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
return nil
}
return ipnet
}
func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
func (m *MeshNodeCrdt) GetRoutes() []string {
return lib.MapKeys(m.Routes)
}
func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4]
constituents := strings.Split(ipv6, ":")
logging.Log.WriteInfof(ipv6)
constituents = constituents[4:]
return strings.Join(constituents, ":")
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
for _, node := range m.Nodes {
nodes[node.HostEndpoint] = &MeshNodeCrdt{
HostEndpoint: node.HostEndpoint,
WgEndpoint: node.WgEndpoint,
PublicKey: node.PublicKey,
WgHost: node.WgHost,
Timestamp: node.Timestamp,
Routes: node.Routes,
}
}
return nodes
}