diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 7a70f16..95c1e35 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -69,7 +69,7 @@ func getMesh(client *ipcRpc.Client, meshId string) { fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("Wg IP: " + node.WgHost) - fmt.Println("Failed Count: " + strconv.Itoa(node.FailedCount)) + fmt.Println("Failed Count: " + strconv.FormatBool(node.Failed)) fmt.Println("---") } } diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 7fc4b04..b6cfff8 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -1,6 +1,7 @@ package crdt import ( + "errors" "net" "strings" @@ -14,6 +15,7 @@ import ( type CrdtNodeManager struct { MeshId string IfName string + NodeId string Client *wgctrl.Client doc *automerge.Doc } @@ -21,19 +23,18 @@ type CrdtNodeManager struct { const maxFails = 5 func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) { - crdt.FailedCount = automerge.NewCounter(0) + crdt.FailedMap = automerge.NewMap() c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) - } -func (c *CrdtNodeManager) applyWg() error { +func (c *CrdtNodeManager) ApplyWg() error { snapshot, err := c.GetCrdt() if err != nil { return err } - updateWgConf(c.IfName, snapshot.Nodes, *c.Client) + c.updateWgConf(c.IfName, snapshot.Nodes, *c.Client) return nil } @@ -51,7 +52,6 @@ func (c *CrdtNodeManager) Load(bytes []byte) error { } c.doc = doc - c.applyWg() return nil } @@ -67,7 +67,7 @@ func (c *CrdtNodeManager) LoadChanges(changes []byte) error { return err } - return c.applyWg() + return nil } func (c *CrdtNodeManager) SaveChanges() []byte { @@ -75,16 +75,17 @@ func (c *CrdtNodeManager) SaveChanges() []byte { } // NewCrdtNodeManager: Create a new crdt node manager -func NewCrdtNodeManager(meshId, devName string, client *wgctrl.Client) *CrdtNodeManager { +func NewCrdtNodeManager(meshId, hostId, devName string, client *wgctrl.Client) *CrdtNodeManager { var manager CrdtNodeManager manager.MeshId = meshId manager.doc = automerge.New() manager.IfName = devName manager.Client = client + manager.NodeId = hostId return &manager } -func convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { +func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { peerEndpoint, err := net.ResolveUDPAddr("udp", node.WgEndpoint) if err != nil { @@ -108,6 +109,7 @@ func convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { peerConfig := wgtypes.PeerConfig{ PublicKey: peerPublic, + Remove: m.HasFailed(node.HostEndpoint), Endpoint: peerEndpoint, AllowedIPs: allowedIps, } @@ -126,41 +128,30 @@ func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount i return err } - counter, err := node.Map().Get("failedCount") + counterMap, err := node.Map().Get("failedMap") - if err != nil { - return err + if counterMap.Kind() == automerge.KindVoid { + return errors.New("Something went wrong map does not exist") + } + + counter, _ := counterMap.Map().Get(c.NodeId) + + if counter.Kind() == automerge.KindVoid { + err = counterMap.Map().Set(c.NodeId, incAmount) + } else { + if counter.Int64()+incAmount < 0 { + return nil + } + + err = counterMap.Map().Set(c.NodeId, counter.Int64()+1) } - err = counter.Counter().Inc(incAmount) return err } // Increment failed count increments the number of times we have attempted // to contact the node and it's failed func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error { - snapshot, err := c.GetCrdt() - - if err != nil { - return err - } - - count, err := snapshot.Nodes[endpoint].FailedCount.Get() - - if err != nil { - return err - } - - if count >= maxFails { - c.removeNode(endpoint) - logging.InfoLog.Printf("Node %s removed from mesh %s", endpoint, c.MeshId) - return nil - } - - if err != nil { - return err - } - return c.changeFailedCount(c.MeshId, endpoint, +1) } @@ -177,32 +168,69 @@ func (c *CrdtNodeManager) removeNode(endpoint string) error { // Decrement failed count decrements the number of times we have attempted to // contact the node and it's failed func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error { - snapshot, err := c.GetCrdt() - - if err != nil { - return err - } - - count, err := snapshot.Nodes[endpoint].FailedCount.Get() - - if err != nil { - return err - } - - if count < 0 { - return nil - } - return c.changeFailedCount(c.MeshId, endpoint, -1) } -func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { +// GetNode: returns a mesh node crdt. +func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { + node, err := m.doc.Path("nodes").Map().Get(endpoint) + + if err != nil { + return nil, err + } + + meshNode, err := automerge.As[*MeshNodeCrdt](node) + + if err != nil { + return nil, err + } + + return meshNode, nil +} + +const threshold = 2 +const thresholdVotes = 0.1 + +func (m *CrdtNodeManager) Length() int { + return m.doc.Path("nodes").Map().Len() +} + +func (m *CrdtNodeManager) HasFailed(endpoint string) bool { + node, err := m.GetNode(endpoint) + + if err != nil { + logging.InfoLog.Printf("Cannot get node node: %s\n", endpoint) + return true + } + + values, err := node.FailedMap.Values() + + if err != nil { + return true + } + + countFailed := 0 + + for _, value := range values { + count := value.Int64() + + if count >= threshold { + countFailed++ + } + } + + logging.InfoLog.Printf("Count Failed Value: %d\n", countFailed) + logging.InfoLog.Printf("Threshold Value: %d\n", int(thresholdVotes*float64(m.Length())+1)) + return countFailed >= int(thresholdVotes*float64(m.Length())+1) +} + +func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) var count int = 0 for _, n := range nodes { - peer, err := convertMeshNode(n) + peer, err := m.convertMeshNode(n) logging.InfoLog.Println(n.HostEndpoint) if err != nil { diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index 60f89b7..c4f9e1f 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -3,12 +3,12 @@ package crdt import "github.com/automerge/automerge-go" type MeshNodeCrdt struct { - HostEndpoint string `automerge:"hostEndpoint"` - WgEndpoint string `automerge:"wgEndpoint"` - PublicKey string `automerge:"publicKey"` - WgHost string `automerge:"wgHost"` - FailedCount *automerge.Counter `automerge:"failedCount"` - FailedInt int `automerge:"-"` + HostEndpoint string `automerge:"hostEndpoint"` + WgEndpoint string `automerge:"wgEndpoint"` + PublicKey string `automerge:"publicKey"` + WgHost string `automerge:"wgHost"` + FailedMap *automerge.Map `automerge:"failedMap"` + FailedInt int `automerge:"-"` } type MeshCrdt struct { diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index ba5f2a4..53c301e 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -16,7 +16,7 @@ type MeshNode struct { WgEndpoint string PublicKey string WgHost string - FailedCount int + Failed bool } type Mesh struct { diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go index f148c9c..a322188 100644 --- a/pkg/lib/conv.go +++ b/pkg/lib/conv.go @@ -9,6 +9,11 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{} values := make([]V, len(m)-len(exclude)) i := 0 + + if len(m)-len(exclude) <= 0 { + return values + } + for k, v := range m { if _, excluded := exclude[k]; excluded { continue diff --git a/pkg/manager/mesh_manager.go b/pkg/manager/mesh_manager.go index a120976..4393f61 100644 --- a/pkg/manager/mesh_manager.go +++ b/pkg/manager/mesh_manager.go @@ -31,7 +31,7 @@ func (m *MeshManger) CreateMesh(devName string) (string, error) { return "", err } - nodeManager := crdt.NewCrdtNodeManager(key.String(), devName, m.Client) + nodeManager := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, m.Client) m.Meshes[key.String()] = nodeManager return key.String(), nil } @@ -53,9 +53,22 @@ func (m *MeshManger) UpdateMesh(meshId string, changes []byte) error { return nil } +// ApplyWg: applies the wireguard configuration changes +func (m *MeshManger) ApplyWg() error { + for _, mesh := range m.Meshes { + err := mesh.ApplyWg() + + if err != nil { + return err + } + } + + return nil +} + // AddMesh: Add the mesh to the list of meshes func (m *MeshManger) AddMesh(meshId string, devName string, meshBytes []byte) error { - mesh := crdt.NewCrdtNodeManager(meshId, devName, m.Client) + mesh := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, m.Client) err := mesh.Load(meshBytes) if err != nil { diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go index bff6b53..f2206d8 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/robin_requester.go @@ -24,9 +24,18 @@ type RobinIpc struct { func (n *RobinIpc) CreateMesh(name string, reply *string) error { wg.CreateInterface(n.Server.Conf.IfName) - meshId, err := n.Server.MeshManager.CreateMesh("wgmesh") + meshId, err := n.Server.MeshManager.CreateMesh(n.Server.Conf.IfName) + + if err != nil { + return err + } pubKey, err := n.Server.MeshManager.GetPublicKey(meshId) + + if err != nil { + return err + } + nodeIP, err := n.ipAllocator.GetIP(*pubKey, meshId) if err != nil { @@ -220,15 +229,13 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error { nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes)) i := 0 - for _, n := range meshSnapshot.Nodes { - failedInt, _ := n.FailedCount.Get() - + for _, node := range meshSnapshot.Nodes { node := ctrlserver.MeshNode{ - HostEndpoint: n.HostEndpoint, - WgEndpoint: n.WgEndpoint, - PublicKey: n.PublicKey, - WgHost: n.WgHost, - FailedCount: int(failedInt), + HostEndpoint: node.HostEndpoint, + WgEndpoint: node.WgEndpoint, + PublicKey: node.PublicKey, + WgHost: node.WgHost, + Failed: mesh.HasFailed(node.HostEndpoint), } nodes[i] = node diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 723c2b4..abc4bc5 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -37,6 +37,10 @@ func (s *SyncerImpl) Sync(meshId string) error { return err } + if len(snapshot.Nodes) <= 1 { + return nil + } + excludedNodes := map[string]struct{}{ s.manager.HostEndpoint: {}, } @@ -65,7 +69,7 @@ func (s *SyncerImpl) SyncMeshes() error { } } - return nil + return s.manager.ApplyWg() } func NewSyncer(m *manager.MeshManger, r SyncRequester) Syncer {