mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-03-06 09:31:13 +01:00
Automatically remove nodes from the mesh after too many
failed sync attempts.
This commit is contained in:
parent
c200544cee
commit
976dbf2613
@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
ipcRpc "net/rpc"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/akamensky/argparse"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
||||
@ -68,6 +69,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
|
||||
fmt.Println("Control Endpoint: " + node.HostEndpoint)
|
||||
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
|
||||
fmt.Println("Wg IP: " + node.WgHost)
|
||||
fmt.Println("Failed Count: " + strconv.Itoa(node.FailedCount))
|
||||
fmt.Println("---")
|
||||
}
|
||||
}
|
||||
|
@ -18,8 +18,12 @@ type CrdtNodeManager struct {
|
||||
doc *automerge.Doc
|
||||
}
|
||||
|
||||
const maxFails = 5
|
||||
|
||||
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
|
||||
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
|
||||
crdt.FailedCount = automerge.NewCounter(0)
|
||||
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
|
||||
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) applyWg() error {
|
||||
@ -115,6 +119,83 @@ func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
|
||||
return strings.Compare(m1.PublicKey, m2.PublicKey)
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
|
||||
node, err := c.doc.Path("nodes").Map().Get(endpoint)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
counter, err := node.Map().Get("failedCount")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = counter.Counter().Inc(incAmount)
|
||||
return err
|
||||
}
|
||||
|
||||
// Increment failed count increments the number of times we have attempted
|
||||
// to contact the node and it's failed
|
||||
func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
|
||||
snapshot, err := c.GetCrdt()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := snapshot.Nodes[endpoint].FailedCount.Get()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count >= maxFails {
|
||||
c.removeNode(endpoint)
|
||||
logging.InfoLog.Printf("Node %s removed from mesh %s", endpoint, c.MeshId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.changeFailedCount(c.MeshId, endpoint, +1)
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) removeNode(endpoint string) error {
|
||||
err := c.doc.Path("nodes").Map().Delete(endpoint)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrement failed count decrements the number of times we have attempted to
|
||||
// contact the node and it's failed
|
||||
func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
|
||||
snapshot, err := c.GetCrdt()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := snapshot.Nodes[endpoint].FailedCount.Get()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.changeFailedCount(c.MeshId, endpoint, -1)
|
||||
}
|
||||
|
||||
func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
||||
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
package crdt
|
||||
|
||||
import "github.com/automerge/automerge-go"
|
||||
|
||||
type MeshNodeCrdt struct {
|
||||
HostEndpoint string `automerge:"hostEndpoint"`
|
||||
WgEndpoint string `automerge:"wgEndpoint"`
|
||||
PublicKey string `automerge:"publicKey"`
|
||||
WgHost string `automerge:"wgHost"`
|
||||
HostEndpoint string `automerge:"hostEndpoint"`
|
||||
WgEndpoint string `automerge:"wgEndpoint"`
|
||||
PublicKey string `automerge:"publicKey"`
|
||||
WgHost string `automerge:"wgHost"`
|
||||
FailedCount *automerge.Counter `automerge:"failedCount"`
|
||||
FailedInt int `automerge:"-"`
|
||||
}
|
||||
|
||||
type MeshCrdt struct {
|
||||
|
@ -30,7 +30,7 @@ type NewCtrlServerParams struct {
|
||||
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||
ctrlServer := new(MeshCtrlServer)
|
||||
ctrlServer.Client = params.WgClient
|
||||
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient)
|
||||
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient, *params.Conf)
|
||||
ctrlServer.Conf = params.Conf
|
||||
|
||||
connManagerParams := conn.NewJwtConnectionManagerParams{
|
||||
|
@ -16,6 +16,7 @@ type MeshNode struct {
|
||||
WgEndpoint string
|
||||
PublicKey string
|
||||
WgHost string
|
||||
FailedCount int
|
||||
}
|
||||
|
||||
type Mesh struct {
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"net/rpc"
|
||||
"os"
|
||||
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
||||
)
|
||||
|
||||
type JoinMeshArgs struct {
|
||||
@ -16,7 +16,7 @@ type JoinMeshArgs struct {
|
||||
}
|
||||
|
||||
type GetMeshReply struct {
|
||||
Nodes []crdt.MeshNodeCrdt
|
||||
Nodes []ctrlserver.MeshNode
|
||||
}
|
||||
|
||||
type ListMeshReply struct {
|
||||
|
@ -9,7 +9,7 @@ func RandomSubsetOfLength[V any](vs []V, num int) []V {
|
||||
selectedIndices := make(map[int]struct{})
|
||||
|
||||
for i := 0; i < num; {
|
||||
if len(selectedIndices) == len(vs) {
|
||||
if len(randomSubset) == len(vs) {
|
||||
return randomSubset
|
||||
}
|
||||
|
||||
|
@ -2,16 +2,20 @@ package manager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type MeshManger struct {
|
||||
Meshes map[string]*crdt.CrdtNodeManager
|
||||
Client *wgctrl.Client
|
||||
Meshes map[string]*crdt.CrdtNodeManager
|
||||
Client *wgctrl.Client
|
||||
HostEndpoint string
|
||||
}
|
||||
|
||||
func (m *MeshManger) MeshExists(meshId string) bool {
|
||||
@ -86,13 +90,7 @@ func (s *MeshManger) EnableInterface(meshId string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err := s.Client.Device(mesh.IfName)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node, contains := crdt.Nodes[dev.PublicKey.String()]
|
||||
node, contains := crdt.Nodes[s.HostEndpoint]
|
||||
|
||||
if !contains {
|
||||
return errors.New("Node does not exist in the mesh")
|
||||
@ -118,6 +116,12 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
|
||||
return &dev.PublicKey, nil
|
||||
}
|
||||
|
||||
func NewMeshManager(client wgctrl.Client) *MeshManger {
|
||||
return &MeshManger{Meshes: make(map[string]*crdt.CrdtNodeManager), Client: &client}
|
||||
func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger {
|
||||
ip := lib.GetOutboundIP()
|
||||
|
||||
return &MeshManger{
|
||||
Meshes: make(map[string]*crdt.CrdtNodeManager),
|
||||
Client: &client,
|
||||
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
|
||||
}
|
||||
}
|
||||
|
@ -217,11 +217,21 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
|
||||
}
|
||||
|
||||
if mesh != nil {
|
||||
nodes := make([]crdt.MeshNodeCrdt, len(meshSnapshot.Nodes))
|
||||
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
|
||||
|
||||
i := 0
|
||||
for _, n := range meshSnapshot.Nodes {
|
||||
nodes[i] = n
|
||||
failedInt, _ := n.FailedCount.Get()
|
||||
|
||||
node := ctrlserver.MeshNode{
|
||||
HostEndpoint: n.HostEndpoint,
|
||||
WgEndpoint: n.WgEndpoint,
|
||||
PublicKey: n.PublicKey,
|
||||
WgHost: n.WgHost,
|
||||
FailedCount: int(failedInt),
|
||||
}
|
||||
|
||||
nodes[i] = node
|
||||
i += 1
|
||||
}
|
||||
|
||||
|
@ -37,14 +37,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
pubKey, err := s.manager.GetPublicKey(meshId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
excludedNodes := map[string]struct{}{
|
||||
pubKey.String(): {},
|
||||
s.manager.HostEndpoint: {},
|
||||
}
|
||||
|
||||
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
|
||||
|
49
pkg/sync/syncererror.go
Normal file
49
pkg/sync/syncererror.go
Normal file
@ -0,0 +1,49 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
"github.com/tim-beatham/wgmesh/pkg/manager"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type SyncErrorHandler interface {
|
||||
Handle(meshId string, endpoint string, err error) bool
|
||||
}
|
||||
|
||||
type SyncErrorHandlerImpl struct {
|
||||
meshManager *manager.MeshManger
|
||||
}
|
||||
|
||||
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
|
||||
mesh := s.meshManager.GetMesh(meshId)
|
||||
|
||||
if mesh == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := mesh.IncrementFailedCount(endpoint)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
|
||||
errStatus, _ := status.FromError(err)
|
||||
|
||||
logging.WarningLog.Printf("Handled gRPC error: %s", errStatus.Message())
|
||||
|
||||
switch errStatus.Code() {
|
||||
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
||||
return s.incrementFailedCount(meshId, endpoint)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func NewSyncErrorHandler(m *manager.MeshManger) SyncErrorHandler {
|
||||
return &SyncErrorHandlerImpl{meshManager: m}
|
||||
}
|
@ -17,11 +17,11 @@ type SyncRequester interface {
|
||||
}
|
||||
|
||||
type SyncRequesterImpl struct {
|
||||
server *ctrlserver.MeshCtrlServer
|
||||
server *ctrlserver.MeshCtrlServer
|
||||
errorHdlr SyncErrorHandler
|
||||
}
|
||||
|
||||
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
|
||||
|
||||
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
|
||||
|
||||
if err != nil {
|
||||
@ -77,6 +77,16 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
|
||||
ok := s.errorHdlr.Handle(meshId, endpoint, err)
|
||||
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SyncMesh: Proactively send a sync request to the other mesh
|
||||
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
||||
if !s.server.ConnectionManager.HasConnection(endpoint) {
|
||||
@ -92,7 +102,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
||||
err = peerConnection.Connect()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return s.handleErr(meshId, endpoint, err)
|
||||
}
|
||||
|
||||
client, err := peerConnection.GetClient()
|
||||
@ -126,13 +136,15 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
||||
_, err = c.SyncMesh(ctx, &syncMeshRequest)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return s.handleErr(meshId, endpoint, err)
|
||||
}
|
||||
|
||||
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
||||
mesh.DecrementFailedCount(endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
|
||||
return &SyncRequesterImpl{server: s}
|
||||
errorHdlr := NewSyncErrorHandler(s.MeshManager)
|
||||
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user