mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-06-21 12:32:18 +02:00
Automatically remove nodes from the mesh after a
certain threshold.
This commit is contained in:
parent
c200544cee
commit
976dbf2613
@ -5,6 +5,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
ipcRpc "net/rpc"
|
ipcRpc "net/rpc"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/akamensky/argparse"
|
"github.com/akamensky/argparse"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
||||||
@ -68,6 +69,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
|
|||||||
fmt.Println("Control Endpoint: " + node.HostEndpoint)
|
fmt.Println("Control Endpoint: " + node.HostEndpoint)
|
||||||
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
|
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
|
||||||
fmt.Println("Wg IP: " + node.WgHost)
|
fmt.Println("Wg IP: " + node.WgHost)
|
||||||
|
fmt.Println("Failed Count: " + strconv.Itoa(node.FailedCount))
|
||||||
fmt.Println("---")
|
fmt.Println("---")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,8 +18,12 @@ type CrdtNodeManager struct {
|
|||||||
doc *automerge.Doc
|
doc *automerge.Doc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maxFails = 5
|
||||||
|
|
||||||
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
|
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
|
||||||
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
|
crdt.FailedCount = automerge.NewCounter(0)
|
||||||
|
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CrdtNodeManager) applyWg() error {
|
func (c *CrdtNodeManager) applyWg() error {
|
||||||
@ -115,6 +119,83 @@ func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
|
|||||||
return strings.Compare(m1.PublicKey, m2.PublicKey)
|
return strings.Compare(m1.PublicKey, m2.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
|
||||||
|
node, err := c.doc.Path("nodes").Map().Get(endpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
counter, err := node.Map().Get("failedCount")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = counter.Counter().Inc(incAmount)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment failed count increments the number of times we have attempted
|
||||||
|
// to contact the node and it's failed
|
||||||
|
func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
|
||||||
|
snapshot, err := c.GetCrdt()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := snapshot.Nodes[endpoint].FailedCount.Get()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count >= maxFails {
|
||||||
|
c.removeNode(endpoint)
|
||||||
|
logging.InfoLog.Printf("Node %s removed from mesh %s", endpoint, c.MeshId)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.changeFailedCount(c.MeshId, endpoint, +1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CrdtNodeManager) removeNode(endpoint string) error {
|
||||||
|
err := c.doc.Path("nodes").Map().Delete(endpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement failed count decrements the number of times we have attempted to
|
||||||
|
// contact the node and it's failed
|
||||||
|
func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
|
||||||
|
snapshot, err := c.GetCrdt()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := snapshot.Nodes[endpoint].FailedCount.Get()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.changeFailedCount(c.MeshId, endpoint, -1)
|
||||||
|
}
|
||||||
|
|
||||||
func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
||||||
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
|
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
|
||||||
|
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
package crdt
|
package crdt
|
||||||
|
|
||||||
|
import "github.com/automerge/automerge-go"
|
||||||
|
|
||||||
type MeshNodeCrdt struct {
|
type MeshNodeCrdt struct {
|
||||||
HostEndpoint string `automerge:"hostEndpoint"`
|
HostEndpoint string `automerge:"hostEndpoint"`
|
||||||
WgEndpoint string `automerge:"wgEndpoint"`
|
WgEndpoint string `automerge:"wgEndpoint"`
|
||||||
PublicKey string `automerge:"publicKey"`
|
PublicKey string `automerge:"publicKey"`
|
||||||
WgHost string `automerge:"wgHost"`
|
WgHost string `automerge:"wgHost"`
|
||||||
|
FailedCount *automerge.Counter `automerge:"failedCount"`
|
||||||
|
FailedInt int `automerge:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MeshCrdt struct {
|
type MeshCrdt struct {
|
||||||
|
@ -30,7 +30,7 @@ type NewCtrlServerParams struct {
|
|||||||
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||||
ctrlServer := new(MeshCtrlServer)
|
ctrlServer := new(MeshCtrlServer)
|
||||||
ctrlServer.Client = params.WgClient
|
ctrlServer.Client = params.WgClient
|
||||||
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient)
|
ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient, *params.Conf)
|
||||||
ctrlServer.Conf = params.Conf
|
ctrlServer.Conf = params.Conf
|
||||||
|
|
||||||
connManagerParams := conn.NewJwtConnectionManagerParams{
|
connManagerParams := conn.NewJwtConnectionManagerParams{
|
||||||
|
@ -16,6 +16,7 @@ type MeshNode struct {
|
|||||||
WgEndpoint string
|
WgEndpoint string
|
||||||
PublicKey string
|
PublicKey string
|
||||||
WgHost string
|
WgHost string
|
||||||
|
FailedCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Mesh struct {
|
type Mesh struct {
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"net/rpc"
|
"net/rpc"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JoinMeshArgs struct {
|
type JoinMeshArgs struct {
|
||||||
@ -16,7 +16,7 @@ type JoinMeshArgs struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetMeshReply struct {
|
type GetMeshReply struct {
|
||||||
Nodes []crdt.MeshNodeCrdt
|
Nodes []ctrlserver.MeshNode
|
||||||
}
|
}
|
||||||
|
|
||||||
type ListMeshReply struct {
|
type ListMeshReply struct {
|
||||||
|
@ -9,7 +9,7 @@ func RandomSubsetOfLength[V any](vs []V, num int) []V {
|
|||||||
selectedIndices := make(map[int]struct{})
|
selectedIndices := make(map[int]struct{})
|
||||||
|
|
||||||
for i := 0; i < num; {
|
for i := 0; i < num; {
|
||||||
if len(selectedIndices) == len(vs) {
|
if len(randomSubset) == len(vs) {
|
||||||
return randomSubset
|
return randomSubset
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,8 +2,11 @@ package manager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@ -12,6 +15,7 @@ import (
|
|||||||
type MeshManger struct {
|
type MeshManger struct {
|
||||||
Meshes map[string]*crdt.CrdtNodeManager
|
Meshes map[string]*crdt.CrdtNodeManager
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
|
HostEndpoint string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MeshManger) MeshExists(meshId string) bool {
|
func (m *MeshManger) MeshExists(meshId string) bool {
|
||||||
@ -86,13 +90,7 @@ func (s *MeshManger) EnableInterface(meshId string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err := s.Client.Device(mesh.IfName)
|
node, contains := crdt.Nodes[s.HostEndpoint]
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
node, contains := crdt.Nodes[dev.PublicKey.String()]
|
|
||||||
|
|
||||||
if !contains {
|
if !contains {
|
||||||
return errors.New("Node does not exist in the mesh")
|
return errors.New("Node does not exist in the mesh")
|
||||||
@ -118,6 +116,12 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
|
|||||||
return &dev.PublicKey, nil
|
return &dev.PublicKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMeshManager(client wgctrl.Client) *MeshManger {
|
func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger {
|
||||||
return &MeshManger{Meshes: make(map[string]*crdt.CrdtNodeManager), Client: &client}
|
ip := lib.GetOutboundIP()
|
||||||
|
|
||||||
|
return &MeshManger{
|
||||||
|
Meshes: make(map[string]*crdt.CrdtNodeManager),
|
||||||
|
Client: &client,
|
||||||
|
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -217,11 +217,21 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mesh != nil {
|
if mesh != nil {
|
||||||
nodes := make([]crdt.MeshNodeCrdt, len(meshSnapshot.Nodes))
|
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
for _, n := range meshSnapshot.Nodes {
|
for _, n := range meshSnapshot.Nodes {
|
||||||
nodes[i] = n
|
failedInt, _ := n.FailedCount.Get()
|
||||||
|
|
||||||
|
node := ctrlserver.MeshNode{
|
||||||
|
HostEndpoint: n.HostEndpoint,
|
||||||
|
WgEndpoint: n.WgEndpoint,
|
||||||
|
PublicKey: n.PublicKey,
|
||||||
|
WgHost: n.WgHost,
|
||||||
|
FailedCount: int(failedInt),
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes[i] = node
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,14 +37,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := s.manager.GetPublicKey(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
excludedNodes := map[string]struct{}{
|
excludedNodes := map[string]struct{}{
|
||||||
pubKey.String(): {},
|
s.manager.HostEndpoint: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
|
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
|
||||||
|
49
pkg/sync/syncererror.go
Normal file
49
pkg/sync/syncererror.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/manager"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SyncErrorHandler decides what to do with an error raised while syncing
// with a peer.
type SyncErrorHandler interface {
	// Handle processes err, which occurred while contacting endpoint in the
	// mesh identified by meshId, and reports whether the error was handled.
	Handle(meshId, endpoint string, err error) bool
}
|
||||||
|
|
||||||
|
type SyncErrorHandlerImpl struct {
|
||||||
|
meshManager *manager.MeshManger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
|
||||||
|
mesh := s.meshManager.GetMesh(meshId)
|
||||||
|
|
||||||
|
if mesh == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mesh.IncrementFailedCount(endpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
|
||||||
|
errStatus, _ := status.FromError(err)
|
||||||
|
|
||||||
|
logging.WarningLog.Printf("Handled gRPC error: %s", errStatus.Message())
|
||||||
|
|
||||||
|
switch errStatus.Code() {
|
||||||
|
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
||||||
|
return s.incrementFailedCount(meshId, endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSyncErrorHandler(m *manager.MeshManger) SyncErrorHandler {
|
||||||
|
return &SyncErrorHandlerImpl{meshManager: m}
|
||||||
|
}
|
@ -18,10 +18,10 @@ type SyncRequester interface {
|
|||||||
|
|
||||||
type SyncRequesterImpl struct {
|
type SyncRequesterImpl struct {
|
||||||
server *ctrlserver.MeshCtrlServer
|
server *ctrlserver.MeshCtrlServer
|
||||||
|
errorHdlr SyncErrorHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
|
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
|
||||||
|
|
||||||
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
|
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -77,6 +77,16 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
|
||||||
|
ok := s.errorHdlr.Handle(meshId, endpoint, err)
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// SyncMesh: Proactively send a sync request to the other mesh
|
// SyncMesh: Proactively send a sync request to the other mesh
|
||||||
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
||||||
if !s.server.ConnectionManager.HasConnection(endpoint) {
|
if !s.server.ConnectionManager.HasConnection(endpoint) {
|
||||||
@ -92,7 +102,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
err = peerConnection.Connect()
|
err = peerConnection.Connect()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return s.handleErr(meshId, endpoint, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := peerConnection.GetClient()
|
client, err := peerConnection.GetClient()
|
||||||
@ -126,13 +136,15 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
_, err = c.SyncMesh(ctx, &syncMeshRequest)
|
_, err = c.SyncMesh(ctx, &syncMeshRequest)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return s.handleErr(meshId, endpoint, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
||||||
|
mesh.DecrementFailedCount(endpoint)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
|
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
|
||||||
return &SyncRequesterImpl{server: s}
|
errorHdlr := NewSyncErrorHandler(s.MeshManager)
|
||||||
|
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user