Automatically remove nodes from the mesh after a certain threshold of failed contact attempts.
This commit is contained in:
Tim Beatham 2023-10-20 17:35:02 +01:00
parent c200544cee
commit 976dbf2613
12 changed files with 191 additions and 34 deletions

View File

@ -5,6 +5,7 @@ import (
"log" "log"
ipcRpc "net/rpc" ipcRpc "net/rpc"
"os" "os"
"strconv"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
@ -68,6 +69,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost) fmt.Println("Wg IP: " + node.WgHost)
fmt.Println("Failed Count: " + strconv.Itoa(node.FailedCount))
fmt.Println("---") fmt.Println("---")
} }
} }

View File

@ -18,8 +18,12 @@ type CrdtNodeManager struct {
doc *automerge.Doc doc *automerge.Doc
} }
// maxFails is the failed count at which IncrementFailedCount removes a node
// from the mesh instead of incrementing its counter again.
const maxFails = 5
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) { func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) crdt.FailedCount = automerge.NewCounter(0)
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
func (c *CrdtNodeManager) applyWg() error { func (c *CrdtNodeManager) applyWg() error {
@ -115,6 +119,83 @@ func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey) return strings.Compare(m1.PublicKey, m2.PublicKey)
} }
// changeFailedCount adds incAmount (which may be negative) to the
// "failedCount" counter of the node stored under endpoint in the document's
// "nodes" map. Errors from looking up the node, looking up its counter, or
// applying the increment are returned unchanged. meshId is not used.
func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
	nodeEntry, err := c.doc.Path("nodes").Map().Get(endpoint)
	if err != nil {
		return err
	}

	failedCounter, err := nodeEntry.Map().Get("failedCount")
	if err != nil {
		return err
	}

	return failedCounter.Counter().Inc(incAmount)
}
// IncrementFailedCount records one more failed attempt to contact the node
// stored under endpoint. If the node's failed count has already reached
// maxFails, the node is removed from the mesh instead of being incremented.
//
// It returns any error from reading the snapshot, reading the counter,
// removing the node, or updating the counter.
func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
	snapshot, err := c.GetCrdt()
	if err != nil {
		return err
	}

	// NOTE(review): if endpoint is absent from the snapshot the lookup yields a
	// zero-value node whose FailedCount is nil — confirm callers only pass
	// endpoints that are known to be in the mesh.
	count, err := snapshot.Nodes[endpoint].FailedCount.Get()
	if err != nil {
		return err
	}

	if count >= maxFails {
		// Report the removal only once it has actually succeeded; a failed
		// delete must not be logged as a removal or hidden from the caller.
		if err := c.removeNode(endpoint); err != nil {
			return err
		}

		logging.InfoLog.Printf("Node %s removed from mesh %s", endpoint, c.MeshId)
		return nil
	}

	return c.changeFailedCount(c.MeshId, endpoint, +1)
}
// removeNode deletes the entry stored under endpoint from the document's
// "nodes" map and returns whatever error the delete reports.
func (c *CrdtNodeManager) removeNode(endpoint string) error {
	return c.doc.Path("nodes").Map().Delete(endpoint)
}
// DecrementFailedCount lowers by one the number of failed attempts recorded
// for the node stored under endpoint. A count that is already zero (or below)
// is left untouched so the failed count never goes negative.
//
// It returns any error from reading the snapshot, reading the counter, or
// updating the counter.
func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
	snapshot, err := c.GetCrdt()
	if err != nil {
		return err
	}

	// NOTE(review): if endpoint is absent from the snapshot the lookup yields a
	// zero-value node whose FailedCount is nil — confirm callers only pass
	// endpoints that are known to be in the mesh.
	count, err := snapshot.Nodes[endpoint].FailedCount.Get()
	if err != nil {
		return err
	}

	// Nothing to undo: decrementing from zero would leave the node at -1 and
	// effectively grant it an extra failure before removal.
	if count <= 0 {
		return nil
	}

	return c.changeFailedCount(c.MeshId, endpoint, -1)
}
func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) peerConfigs := make([]wgtypes.PeerConfig, len(nodes))

View File

@ -1,10 +1,14 @@
package crdt package crdt
import "github.com/automerge/automerge-go"
type MeshNodeCrdt struct { type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"` HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"` WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"` PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
FailedCount *automerge.Counter `automerge:"failedCount"`
FailedInt int `automerge:"-"`
} }
type MeshCrdt struct { type MeshCrdt struct {

View File

@ -30,7 +30,7 @@ type NewCtrlServerParams struct {
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
ctrlServer.Client = params.WgClient ctrlServer.Client = params.WgClient
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient) ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient, *params.Conf)
ctrlServer.Conf = params.Conf ctrlServer.Conf = params.Conf
connManagerParams := conn.NewJwtConnectionManagerParams{ connManagerParams := conn.NewJwtConnectionManagerParams{

View File

@ -16,6 +16,7 @@ type MeshNode struct {
WgEndpoint string WgEndpoint string
PublicKey string PublicKey string
WgHost string WgHost string
FailedCount int
} }
type Mesh struct { type Mesh struct {

View File

@ -7,7 +7,7 @@ import (
"net/rpc" "net/rpc"
"os" "os"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
) )
type JoinMeshArgs struct { type JoinMeshArgs struct {
@ -16,7 +16,7 @@ type JoinMeshArgs struct {
} }
type GetMeshReply struct { type GetMeshReply struct {
Nodes []crdt.MeshNodeCrdt Nodes []ctrlserver.MeshNode
} }
type ListMeshReply struct { type ListMeshReply struct {

View File

@ -9,7 +9,7 @@ func RandomSubsetOfLength[V any](vs []V, num int) []V {
selectedIndices := make(map[int]struct{}) selectedIndices := make(map[int]struct{})
for i := 0; i < num; { for i := 0; i < num; {
if len(selectedIndices) == len(vs) { if len(randomSubset) == len(vs) {
return randomSubset return randomSubset
} }

View File

@ -2,16 +2,20 @@ package manager
import ( import (
"errors" "errors"
"fmt"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge" crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type MeshManger struct { type MeshManger struct {
Meshes map[string]*crdt.CrdtNodeManager Meshes map[string]*crdt.CrdtNodeManager
Client *wgctrl.Client Client *wgctrl.Client
HostEndpoint string
} }
func (m *MeshManger) MeshExists(meshId string) bool { func (m *MeshManger) MeshExists(meshId string) bool {
@ -86,13 +90,7 @@ func (s *MeshManger) EnableInterface(meshId string) error {
return err return err
} }
dev, err := s.Client.Device(mesh.IfName) node, contains := crdt.Nodes[s.HostEndpoint]
if err != nil {
return err
}
node, contains := crdt.Nodes[dev.PublicKey.String()]
if !contains { if !contains {
return errors.New("Node does not exist in the mesh") return errors.New("Node does not exist in the mesh")
@ -118,6 +116,12 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
return &dev.PublicKey, nil return &dev.PublicKey, nil
} }
func NewMeshManager(client wgctrl.Client) *MeshManger { func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger {
return &MeshManger{Meshes: make(map[string]*crdt.CrdtNodeManager), Client: &client} ip := lib.GetOutboundIP()
return &MeshManger{
Meshes: make(map[string]*crdt.CrdtNodeManager),
Client: &client,
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
}
} }

View File

@ -217,11 +217,21 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
} }
if mesh != nil { if mesh != nil {
nodes := make([]crdt.MeshNodeCrdt, len(meshSnapshot.Nodes)) nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
i := 0 i := 0
for _, n := range meshSnapshot.Nodes { for _, n := range meshSnapshot.Nodes {
nodes[i] = n failedInt, _ := n.FailedCount.Get()
node := ctrlserver.MeshNode{
HostEndpoint: n.HostEndpoint,
WgEndpoint: n.WgEndpoint,
PublicKey: n.PublicKey,
WgHost: n.WgHost,
FailedCount: int(failedInt),
}
nodes[i] = node
i += 1 i += 1
} }

View File

@ -37,14 +37,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
return err return err
} }
pubKey, err := s.manager.GetPublicKey(meshId)
if err != nil {
return err
}
excludedNodes := map[string]struct{}{ excludedNodes := map[string]struct{}{
pubKey.String(): {}, s.manager.HostEndpoint: {},
} }
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes) meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)

49
pkg/sync/syncererror.go Normal file
View File

@ -0,0 +1,49 @@
package sync
import (
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/manager"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// SyncErrorHandler decides what to do with an error raised while syncing
// with a peer.
type SyncErrorHandler interface {
	// Handle processes err, which occurred while contacting endpoint in the
	// mesh meshId, and reports whether the error was handled.
	Handle(meshId, endpoint string, err error) bool
}
// SyncErrorHandlerImpl is the SyncErrorHandler returned by
// NewSyncErrorHandler. It records failed contact attempts on the meshes held
// by its mesh manager.
type SyncErrorHandlerImpl struct {
	// meshManager is used to look up the mesh whose node failed.
	meshManager *manager.MeshManger
}
// incrementFailedCount bumps the failed count of endpoint in the mesh meshId.
// It reports false when the mesh is not known to the manager or when the
// increment returns an error, and true otherwise.
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
	targetMesh := s.meshManager.GetMesh(meshId)
	if targetMesh == nil {
		return false
	}

	return targetMesh.IncrementFailedCount(endpoint) == nil
}
// Handle logs the gRPC status message carried by err and, when the status
// code is Unavailable, Unknown, DeadlineExceeded, Internal or NotFound,
// increments the failed count of endpoint in the mesh meshId. It returns the
// result of that increment for those codes and false for every other code.
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
	grpcStatus, _ := status.FromError(err)

	logging.WarningLog.Printf("Handled gRPC error: %s", grpcStatus.Message())

	switch grpcStatus.Code() {
	case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
		return s.incrementFailedCount(meshId, endpoint)
	default:
		return false
	}
}
// NewSyncErrorHandler returns a SyncErrorHandler that records failures on the
// meshes managed by m.
func NewSyncErrorHandler(m *manager.MeshManger) SyncErrorHandler {
	handler := &SyncErrorHandlerImpl{meshManager: m}
	return handler
}

View File

@ -17,11 +17,11 @@ type SyncRequester interface {
} }
type SyncRequesterImpl struct { type SyncRequesterImpl struct {
server *ctrlserver.MeshCtrlServer server *ctrlserver.MeshCtrlServer
errorHdlr SyncErrorHandler
} }
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error { func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint) peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
if err != nil { if err != nil {
@ -77,6 +77,16 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
return err return err
} }
// handleErr passes err to the configured error handler. It returns nil when
// the handler reports the error as handled, and err unchanged otherwise.
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
	if s.errorHdlr.Handle(meshId, endpoint, err) {
		return nil
	}

	return err
}
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
if !s.server.ConnectionManager.HasConnection(endpoint) { if !s.server.ConnectionManager.HasConnection(endpoint) {
@ -92,7 +102,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
err = peerConnection.Connect() err = peerConnection.Connect()
if err != nil { if err != nil {
return err return s.handleErr(meshId, endpoint, err)
} }
client, err := peerConnection.GetClient() client, err := peerConnection.GetClient()
@ -126,13 +136,15 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
_, err = c.SyncMesh(ctx, &syncMeshRequest) _, err = c.SyncMesh(ctx, &syncMeshRequest)
if err != nil { if err != nil {
return err return s.handleErr(meshId, endpoint, err)
} }
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
mesh.DecrementFailedCount(endpoint)
return nil return nil
} }
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester { func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
return &SyncRequesterImpl{server: s} errorHdlr := NewSyncErrorHandler(s.MeshManager)
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
} }