diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go
index cc16ef7..7a70f16 100644
--- a/cmd/wg-mesh/main.go
+++ b/cmd/wg-mesh/main.go
@@ -5,6 +5,7 @@ import (
 	"log"
 	ipcRpc "net/rpc"
 	"os"
+	"strconv"
 
 	"github.com/akamensky/argparse"
 	"github.com/tim-beatham/wgmesh/pkg/ipc"
@@ -68,6 +69,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
 		fmt.Println("Control Endpoint: " + node.HostEndpoint)
 		fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
 		fmt.Println("Wg IP: " + node.WgHost)
+		fmt.Println("Failed Count: " + strconv.Itoa(node.FailedCount))
 		fmt.Println("---")
 	}
 }
diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go
index 7aea8ea..7fc4b04 100644
--- a/pkg/automerge/automerge.go
+++ b/pkg/automerge/automerge.go
@@ -18,8 +18,12 @@ type CrdtNodeManager struct {
 	doc *automerge.Doc
 }
 
+const maxFails = 5
+
 func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
-	c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
+	crdt.FailedCount = automerge.NewCounter(0)
+	c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
+
 }
 
 func (c *CrdtNodeManager) applyWg() error {
@@ -115,6 +119,104 @@ func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
 	return strings.Compare(m1.PublicKey, m2.PublicKey)
 }
 
+// changeFailedCount adds incAmount (which may be negative) to the failed
+// counter of the node stored under endpoint in the "nodes" map. meshId is
+// currently unused.
+func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
+	node, err := c.doc.Path("nodes").Map().Get(endpoint)
+
+	if err != nil {
+		return err
+	}
+
+	counter, err := node.Map().Get("failedCount")
+
+	if err != nil {
+		return err
+	}
+
+	return counter.Counter().Inc(incAmount)
+}
+
+// failedCount reads the failed count of the node stored under endpoint from a
+// fresh snapshot of the mesh. The boolean result is false when the snapshot
+// has no such node or the node carries no counter, so callers never call
+// Get on a nil *automerge.Counter.
+func (c *CrdtNodeManager) failedCount(endpoint string) (int64, bool, error) {
+	snapshot, err := c.GetCrdt()
+
+	if err != nil {
+		return 0, false, err
+	}
+
+	node, ok := snapshot.Nodes[endpoint]
+
+	if !ok || node.FailedCount == nil {
+		return 0, false, nil
+	}
+
+	count, err := node.FailedCount.Get()
+
+	if err != nil {
+		return 0, false, err
+	}
+
+	return count, true, nil
+}
+
+// IncrementFailedCount records one more failed attempt to contact the node
+// stored under endpoint. Once the node has already accumulated maxFails
+// failures it is removed from the mesh instead. An endpoint that is not in
+// the mesh is logged and otherwise ignored.
+func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
+	count, found, err := c.failedCount(endpoint)
+
+	if err != nil {
+		return err
+	}
+
+	if !found {
+		logging.WarningLog.Printf("Node %s not found in mesh %s", endpoint, c.MeshId)
+		return nil
+	}
+
+	if count >= maxFails {
+		// Only report the removal once it has actually happened.
+		if err := c.removeNode(endpoint); err != nil {
+			return err
+		}
+
+		logging.InfoLog.Printf("Node %s removed from mesh %s", endpoint, c.MeshId)
+		return nil
+	}
+
+	return c.changeFailedCount(c.MeshId, endpoint, +1)
+}
+
+// removeNode deletes the node stored under endpoint from the "nodes" map.
+func (c *CrdtNodeManager) removeNode(endpoint string) error {
+	return c.doc.Path("nodes").Map().Delete(endpoint)
+}
+
+// DecrementFailedCount forgives one failed attempt to contact the node
+// stored under endpoint. The counter is left alone once it has reached zero,
+// and an endpoint that is not in the mesh is ignored.
+func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
+	count, found, err := c.failedCount(endpoint)
+
+	if err != nil {
+		return err
+	}
+
+	// count <= 0 rather than count < 0: decrementing at zero would leave the
+	// counter at -1 and grant the node an extra failure before removal.
+	if !found || count <= 0 {
+		return nil
+	}
+
+	return c.changeFailedCount(c.MeshId, endpoint, -1)
+}
+
 func updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
 	peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
 
diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go
index 58e52f9..60f89b7 100644
--- a/pkg/automerge/types.go
+++ b/pkg/automerge/types.go
@@ -1,10 +1,14 @@
 package crdt
 
+import "github.com/automerge/automerge-go"
+
 type MeshNodeCrdt struct {
-	HostEndpoint string `automerge:"hostEndpoint"`
-	WgEndpoint   string `automerge:"wgEndpoint"`
-	PublicKey    string `automerge:"publicKey"`
-	WgHost       string `automerge:"wgHost"`
+	HostEndpoint string             `automerge:"hostEndpoint"`
+	WgEndpoint   string             `automerge:"wgEndpoint"`
+	PublicKey    string             `automerge:"publicKey"`
+	WgHost       string             `automerge:"wgHost"`
+	FailedCount  *automerge.Counter `automerge:"failedCount"`
+	FailedInt    int                `automerge:"-"`
 }
 
 type MeshCrdt struct {
diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go
index df255b8..cba66ca 100644
--- a/pkg/ctrlserver/ctrlserver.go
+++ b/pkg/ctrlserver/ctrlserver.go
@@ -30,7 +30,7 @@ type NewCtrlServerParams struct {
 func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
 	ctrlServer := new(MeshCtrlServer)
 	ctrlServer.Client = params.WgClient
-	ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient)
+	ctrlServer.MeshManager = manager.NewMeshManager(*params.WgClient, *params.Conf)
 	ctrlServer.Conf = params.Conf
 
 	connManagerParams := conn.NewJwtConnectionManagerParams{
diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go
index cf9edeb..ba5f2a4 100644
--- a/pkg/ctrlserver/ctrltypes.go
+++ b/pkg/ctrlserver/ctrltypes.go
@@ -16,6 +16,7 @@ type MeshNode struct {
 	WgEndpoint   string
 	PublicKey    string
 	WgHost       string
+	FailedCount  int
 }
 
 type Mesh struct {
diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go
index 2e8316e..76a4f9e 100644
--- a/pkg/ipc/ipc.go
+++ b/pkg/ipc/ipc.go
@@ -7,7 +7,7 @@ import (
 	"net/rpc"
 	"os"
 
-	crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
+	"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
 )
 
 type JoinMeshArgs struct {
@@ -16,7 +16,7 @@ type JoinMeshArgs struct {
 }
 
 type GetMeshReply struct {
-	Nodes []crdt.MeshNodeCrdt
+	Nodes []ctrlserver.MeshNode
 }
 
 type ListMeshReply struct {
diff --git a/pkg/lib/random.go b/pkg/lib/random.go
index 90a8414..581c99c 100644
--- a/pkg/lib/random.go
+++ b/pkg/lib/random.go
@@ -9,7 +9,7 @@ func RandomSubsetOfLength[V any](vs []V, num int) []V {
 	selectedIndices := make(map[int]struct{})
 
 	for i := 0; i < num; {
-		if len(selectedIndices) == len(vs) {
+		if len(randomSubset) == len(vs) {
 			return randomSubset
 		}
 
diff --git a/pkg/manager/mesh_manager.go b/pkg/manager/mesh_manager.go
index b3e834f..a120976 100644
--- a/pkg/manager/mesh_manager.go
+++ b/pkg/manager/mesh_manager.go
@@ -2,16 +2,20 @@ package manager
 
 import (
 	"errors"
+	"fmt"
 
 	crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
+	"github.com/tim-beatham/wgmesh/pkg/conf"
+	"github.com/tim-beatham/wgmesh/pkg/lib"
 	"github.com/tim-beatham/wgmesh/pkg/wg"
 	"golang.zx2c4.com/wireguard/wgctrl"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
 type MeshManger struct {
-	Meshes map[string]*crdt.CrdtNodeManager
-	Client *wgctrl.Client
+	Meshes       map[string]*crdt.CrdtNodeManager
+	Client       *wgctrl.Client
+	HostEndpoint string
 }
 
 func (m *MeshManger) MeshExists(meshId string) bool {
@@ -86,13 +90,7 @@ func (s *MeshManger) EnableInterface(meshId string) error {
 		return err
 	}
 
-	dev, err := s.Client.Device(mesh.IfName)
-
-	if err != nil {
-		return err
-	}
-
-	node, contains := crdt.Nodes[dev.PublicKey.String()]
+	node, contains := crdt.Nodes[s.HostEndpoint]
 
 	if !contains {
 		return errors.New("Node does not exist in the mesh")
@@ -118,6 +116,12 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
 	return &dev.PublicKey, nil
 }
 
-func NewMeshManager(client wgctrl.Client) *MeshManger {
-	return &MeshManger{Meshes: make(map[string]*crdt.CrdtNodeManager), Client: &client}
+func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger {
+	ip := lib.GetOutboundIP()
+
+	return &MeshManger{
+		Meshes:       make(map[string]*crdt.CrdtNodeManager),
+		Client:       &client,
+		HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
+	}
 }
diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go
index b421b1c..bff6b53 100644
--- a/pkg/robin/robin_requester.go
+++ b/pkg/robin/robin_requester.go
@@ -217,11 +217,21 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
 	}
 
 	if mesh != nil {
-		nodes := make([]crdt.MeshNodeCrdt, len(meshSnapshot.Nodes))
+		nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
 		i := 0
 
 		for _, n := range meshSnapshot.Nodes {
-			nodes[i] = n
+			failedInt, _ := n.FailedCount.Get()
+
+			node := ctrlserver.MeshNode{
+				HostEndpoint: n.HostEndpoint,
+				WgEndpoint:   n.WgEndpoint,
+				PublicKey:    n.PublicKey,
+				WgHost:       n.WgHost,
+				FailedCount:  int(failedInt),
+			}
+
+			nodes[i] = node
 			i += 1
 		}
 
diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go
index 968219d..723c2b4 100644
--- a/pkg/sync/syncer.go
+++ b/pkg/sync/syncer.go
@@ -37,14 +37,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
 		return err
 	}
 
-	pubKey, err := s.manager.GetPublicKey(meshId)
-
-	if err != nil {
-		return err
-	}
-
 	excludedNodes := map[string]struct{}{
-		pubKey.String(): {},
+		s.manager.HostEndpoint: {},
 	}
 
 	meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go
new file mode 100644
index 0000000..04f9cbc
--- /dev/null
+++ b/pkg/sync/syncererror.go
@@ -0,0 +1,64 @@
+package sync
+
+import (
+	logging "github.com/tim-beatham/wgmesh/pkg/log"
+	"github.com/tim-beatham/wgmesh/pkg/manager"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+// SyncErrorHandler decides how to react to an error raised while syncing
+// with a peer.
+type SyncErrorHandler interface {
+	// Handle reports whether err was dealt with for the node at endpoint in
+	// mesh meshId.
+	Handle(meshId string, endpoint string, err error) bool
+}
+
+// SyncErrorHandlerImpl reacts to sync errors by counting failures against
+// the peer in the affected mesh.
+type SyncErrorHandlerImpl struct {
+	meshManager *manager.MeshManger
+}
+
+// incrementFailedCount bumps the failed count of endpoint in mesh meshId and
+// reports whether that succeeded.
+func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
+	mesh := s.meshManager.GetMesh(meshId)
+
+	if mesh == nil {
+		return false
+	}
+
+	if err := mesh.IncrementFailedCount(endpoint); err != nil {
+		// The boolean result cannot carry the cause, so log it here.
+		logging.WarningLog.Printf("Could not increment failed count of node %s: %s", endpoint, err)
+		return false
+	}
+
+	return true
+}
+
+// Handle counts a failure against endpoint when err maps to one of the gRPC
+// codes in the switch below and reports whether the count was recorded. Any
+// other code is left unhandled and yields false.
+func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
+	// The ok result is deliberately ignored: per the gRPC status docs a
+	// non-gRPC error still yields a Status, with code Unknown, which the
+	// switch below handles.
+	errStatus, _ := status.FromError(err)
+
+	logging.WarningLog.Printf("Handled gRPC error: %s", errStatus.Message())
+
+	switch errStatus.Code() {
+	case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
+		return s.incrementFailedCount(meshId, endpoint)
+	}
+
+	return false
+}
+
+// NewSyncErrorHandler returns a SyncErrorHandler that records failures through m.
+func NewSyncErrorHandler(m *manager.MeshManger) SyncErrorHandler {
+	return &SyncErrorHandlerImpl{meshManager: m}
+}
diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go
index 3bde281..e798062 100644
--- a/pkg/sync/syncrequester.go
+++ b/pkg/sync/syncrequester.go
@@ -17,11 +17,11 @@ type SyncRequester interface {
 }
 
 type SyncRequesterImpl struct {
-	server *ctrlserver.MeshCtrlServer
+	server    *ctrlserver.MeshCtrlServer
+	errorHdlr SyncErrorHandler
 }
 
 func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
-
 	peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
 
 	if err != nil {
@@ -77,6 +77,16 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
 	return err
 }
 
+func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
+	ok := s.errorHdlr.Handle(meshId, endpoint, err)
+
+	if ok {
+		return nil
+	}
+
+	return err
+}
+
 // SyncMesh: Proactively send a sync request to the other mesh
 func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
 	if !s.server.ConnectionManager.HasConnection(endpoint) {
@@ -92,7 +102,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
 	err = peerConnection.Connect()
 
 	if err != nil {
-		return err
+		return s.handleErr(meshId, endpoint, err)
 	}
 
 	client, err := peerConnection.GetClient()
@@ -126,13 +136,21 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
 	_, err = c.SyncMesh(ctx, &syncMeshRequest)
 
 	if err != nil {
-		return err
+		return s.handleErr(meshId, endpoint, err)
 	}
 
 	logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
+
+	// The sync itself succeeded, so a bookkeeping failure is logged rather
+	// than returned.
+	if err := mesh.DecrementFailedCount(endpoint); err != nil {
+		logging.WarningLog.Printf("Could not decrement failed count of node %s: %s", endpoint, err)
+	}
+
 	return nil
 }
 
 func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
-	return &SyncRequesterImpl{server: s}
+	errorHdlr := NewSyncErrorHandler(s.MeshManager)
+	return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
 }