Automatically remove nodes from the mesh after a

certain threshold.
This commit is contained in:
Tim Beatham 2023-10-20 17:35:02 +01:00
parent c200544cee
commit 976dbf2613
12 changed files with 191 additions and 34 deletions

View File

@ -5,6 +5,7 @@ import (
"log"
ipcRpc "net/rpc"
"os"
"strconv"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc"
@ -68,6 +69,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Println("Failed Count: " + strconv.Itoa(node.FailedCount))
fmt.Println("---")
}
}

View File

@ -18,8 +18,12 @@ type CrdtNodeManager struct {
doc *automerge.Doc
}
// maxFails is the failed-contact count at or above which
// IncrementFailedCount removes a node from the mesh instead of
// incrementing its counter any further.
const maxFails = 5
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
crdt.FailedCount = automerge.NewCounter(0)
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
}
func (c *CrdtNodeManager) applyWg() error {
@ -115,6 +119,83 @@ func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
// changeFailedCount adds incAmount (which may be negative) to the
// "failedCount" counter of the node stored under endpoint in the document's
// "nodes" map. Any error from looking up the node, looking up the counter or
// applying the increment is returned unchanged.
//
// meshId is accepted but is not used by this method.
func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
	nodeValue, err := c.doc.Path("nodes").Map().Get(endpoint)
	if err != nil {
		return err
	}

	failedCount, err := nodeValue.Map().Get("failedCount")
	if err != nil {
		return err
	}

	return failedCount.Counter().Inc(incAmount)
}
// IncrementFailedCount records one more failed attempt to contact the node
// at endpoint. If the node's failed count has already reached maxFails the
// node is removed from the mesh instead of being incremented further.
//
// It returns any error from reading the snapshot, reading the counter,
// removing the node or updating the counter.
func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
	snapshot, err := c.GetCrdt()
	if err != nil {
		return err
	}

	// NOTE(review): an endpoint that is missing from snapshot.Nodes would
	// presumably yield a nil FailedCount here; confirm callers only pass
	// endpoints taken from the current snapshot.
	count, err := snapshot.Nodes[endpoint].FailedCount.Get()
	if err != nil {
		return err
	}

	if count >= maxFails {
		// Propagate a failed removal rather than logging a removal that
		// did not actually happen.
		if err := c.removeNode(endpoint); err != nil {
			return err
		}

		logging.InfoLog.Printf("Node %s removed from mesh %s", endpoint, c.MeshId)
		return nil
	}

	return c.changeFailedCount(c.MeshId, endpoint, +1)
}
// removeNode deletes the entry stored under endpoint from the document's
// "nodes" map and returns the error, if any, reported by the deletion.
func (c *CrdtNodeManager) removeNode(endpoint string) error {
	return c.doc.Path("nodes").Map().Delete(endpoint)
}
// DecrementFailedCount lowers by one the number of failed attempts recorded
// for the node at endpoint. The counter is never taken below zero: when it
// is already zero (or negative) the call is a no-op.
//
// It returns any error from reading the snapshot, reading the counter or
// updating the counter.
func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
	snapshot, err := c.GetCrdt()
	if err != nil {
		return err
	}

	count, err := snapshot.Nodes[endpoint].FailedCount.Get()
	if err != nil {
		return err
	}

	// Use <= rather than < so that a counter sitting at zero is not
	// decremented to -1.
	if count <= 0 {
		return nil
	}

	return c.changeFailedCount(c.MeshId, endpoint, -1)
}
func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))

View File

@ -1,10 +1,14 @@
package crdt
import "github.com/automerge/automerge-go"
type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
FailedCount *automerge.Counter `automerge:"failedCount"`
FailedInt int `automerge:"-"`
}
type MeshCrdt struct {

View File

@ -30,7 +30,7 @@ type NewCtrlServerParams struct {
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer)
ctrlServer.Client = params.WgClient
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient)
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient, *params.Conf)
ctrlServer.Conf = params.Conf
connManagerParams := conn.NewJwtConnectionManagerParams{

View File

@ -16,6 +16,7 @@ type MeshNode struct {
WgEndpoint string
PublicKey string
WgHost string
FailedCount int
}
type Mesh struct {

View File

@ -7,7 +7,7 @@ import (
"net/rpc"
"os"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
)
type JoinMeshArgs struct {
@ -16,7 +16,7 @@ type JoinMeshArgs struct {
}
type GetMeshReply struct {
Nodes []crdt.MeshNodeCrdt
Nodes []ctrlserver.MeshNode
}
type ListMeshReply struct {

View File

@ -9,7 +9,7 @@ func RandomSubsetOfLength[V any](vs []V, num int) []V {
selectedIndices := make(map[int]struct{})
for i := 0; i < num; {
if len(selectedIndices) == len(vs) {
if len(randomSubset) == len(vs) {
return randomSubset
}

View File

@ -2,16 +2,20 @@ package manager
import (
"errors"
"fmt"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManger struct {
Meshes map[string]*crdt.CrdtNodeManager
Client *wgctrl.Client
Meshes map[string]*crdt.CrdtNodeManager
Client *wgctrl.Client
HostEndpoint string
}
func (m *MeshManger) MeshExists(meshId string) bool {
@ -86,13 +90,7 @@ func (s *MeshManger) EnableInterface(meshId string) error {
return err
}
dev, err := s.Client.Device(mesh.IfName)
if err != nil {
return err
}
node, contains := crdt.Nodes[dev.PublicKey.String()]
node, contains := crdt.Nodes[s.HostEndpoint]
if !contains {
return errors.New("Node does not exist in the mesh")
@ -118,6 +116,12 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
return &dev.PublicKey, nil
}
func NewMeshManager(client wgctrl.Client) *MeshManger {
return &MeshManger{Meshes: make(map[string]*crdt.CrdtNodeManager), Client: &client}
func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger {
ip := lib.GetOutboundIP()
return &MeshManger{
Meshes: make(map[string]*crdt.CrdtNodeManager),
Client: &client,
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
}
}

View File

@ -217,11 +217,21 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
}
if mesh != nil {
nodes := make([]crdt.MeshNodeCrdt, len(meshSnapshot.Nodes))
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
i := 0
for _, n := range meshSnapshot.Nodes {
nodes[i] = n
failedInt, _ := n.FailedCount.Get()
node := ctrlserver.MeshNode{
HostEndpoint: n.HostEndpoint,
WgEndpoint: n.WgEndpoint,
PublicKey: n.PublicKey,
WgHost: n.WgHost,
FailedCount: int(failedInt),
}
nodes[i] = node
i += 1
}

View File

@ -37,14 +37,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
return err
}
pubKey, err := s.manager.GetPublicKey(meshId)
if err != nil {
return err
}
excludedNodes := map[string]struct{}{
pubKey.String(): {},
s.manager.HostEndpoint: {},
}
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)

49
pkg/sync/syncererror.go Normal file
View File

@ -0,0 +1,49 @@
package sync
import (
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/manager"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// SyncErrorHandler decides what to do with an error raised while syncing
// with a node of a mesh.
type SyncErrorHandler interface {
	// Handle processes err, which occurred while contacting endpoint in the
	// mesh identified by meshId. It reports whether the error was handled.
	Handle(meshId, endpoint string, err error) bool
}
// SyncErrorHandlerImpl is the SyncErrorHandler implementation returned by
// NewSyncErrorHandler.
type SyncErrorHandlerImpl struct {
// meshManager is used to look up the mesh whose failed count is updated.
meshManager *manager.MeshManger
}
// incrementFailedCount bumps the failed count of the node at endpoint in the
// mesh identified by meshId. It returns true when the count was updated and
// false when the mesh is unknown or the update failed.
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
	mesh := s.meshManager.GetMesh(meshId)
	if mesh == nil {
		return false
	}

	if err := mesh.IncrementFailedCount(endpoint); err != nil {
		// The boolean result cannot carry the cause, so log it here
		// instead of dropping it silently.
		logging.WarningLog.Printf("Failed to increment failed count of node %s in mesh %s: %v", endpoint, meshId, err)
		return false
	}

	return true
}
// Handle logs the message of err, interpreted as a gRPC status, and then
// acts on its status code. For Unavailable, Unknown, DeadlineExceeded,
// Internal and NotFound it increments the failed count of the node at
// endpoint in the mesh meshId and returns the result of that update; every
// other code yields false.
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
	// The ok flag from FromError is deliberately ignored.
	// NOTE(review): per the gRPC docs a non-status error should surface as
	// codes.Unknown, which is counted as a failure below; confirm.
	st, _ := status.FromError(err)
	logging.WarningLog.Printf("Handled gRPC error: %s", st.Message())

	switch code := st.Code(); code {
	case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
		return s.incrementFailedCount(meshId, endpoint)
	default:
		return false
	}
}
// NewSyncErrorHandler returns a SyncErrorHandler backed by the mesh
// manager m.
func NewSyncErrorHandler(m *manager.MeshManger) SyncErrorHandler {
	handler := new(SyncErrorHandlerImpl)
	handler.meshManager = m
	return handler
}

View File

@ -17,11 +17,11 @@ type SyncRequester interface {
}
type SyncRequesterImpl struct {
server *ctrlserver.MeshCtrlServer
server *ctrlserver.MeshCtrlServer
errorHdlr SyncErrorHandler
}
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
if err != nil {
@ -77,6 +77,16 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
return err
}
// handleErr passes err to the requester's error handler. When the handler
// reports the error as handled, nil is returned so the caller treats the
// operation as not having failed; otherwise err is returned unchanged.
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
	if s.errorHdlr.Handle(meshId, endpoint, err) {
		return nil
	}
	return err
}
// SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
if !s.server.ConnectionManager.HasConnection(endpoint) {
@ -92,7 +102,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = peerConnection.Connect()
if err != nil {
return err
return s.handleErr(meshId, endpoint, err)
}
client, err := peerConnection.GetClient()
@ -126,13 +136,15 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
_, err = c.SyncMesh(ctx, &syncMeshRequest)
if err != nil {
return err
return s.handleErr(meshId, endpoint, err)
}
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
mesh.DecrementFailedCount(endpoint)
return nil
}
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
return &SyncRequesterImpl{server: s}
errorHdlr := NewSyncErrorHandler(s.MeshManager)
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
}