diff --git a/cmd/wgmeshd/configuration.yaml b/cmd/wgmeshd/configuration.yaml index 103e958..bde7ff8 100644 --- a/cmd/wgmeshd/configuration.yaml +++ b/cmd/wgmeshd/configuration.yaml @@ -10,4 +10,5 @@ syncRate: 1 interClusterChance: 0.15 branchRate: 3 infectionCount: 3 -keepAliveTime: 60 \ No newline at end of file +keepAliveTime: 10 +pruneTime: 20 \ No newline at end of file diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index c879d07..facc6bf 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -9,6 +9,7 @@ import ( ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/wgmesh/pkg/timestamp" @@ -44,12 +45,13 @@ func main() { SyncProvider: &syncProvider, Client: client, } - ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) + syncProvider.Server = ctrlServer syncRequester := sync.NewSyncRequester(ctrlServer) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer) + pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) robinIpcParams := robin.RobinIpcParams{ CtrlServer: ctrlServer, @@ -68,6 +70,7 @@ func main() { go ipc.RunIpcHandler(&robinIpc) go syncScheduler.Run() go timestampScheduler.Run() + go pruneScheduler.Run() closeResources := func() { logging.Log.WriteInfof("Closing resources") diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 1443a9a..3d20024 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -34,10 +34,10 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { panic("node must be of type *MeshNodeCrdt") } + crdt.Routes = make(map[string]interface{}) + crdt.Timestamp = time.Now().Unix() c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) - nodeVal, _ 
:= c.doc.Path("nodes").Map().Get(crdt.HostEndpoint) - nodeVal.Map().Set("routes", automerge.NewMap()) } // GetMesh(): Converts the document into a struct @@ -204,7 +204,6 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { return err } } - return nil } @@ -212,63 +211,58 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { return NewAutomergeSync(m) } -// getHealthMap returns the health map from the automerge CRDT -func (m *CrdtMeshManager) getHealthMap(nodeId string) (*automerge.Map, error) { - node, err := m.doc.Path("nodes").Map().Get(nodeId) - - if err != nil { - return nil, err - } - - if node.Kind() != automerge.KindMap { - return nil, errors.New("node should be a map") - } - - nodeMap := node.Map() - - health, err := nodeMap.Get("health") - - if err != nil { - return nil, err - } - - if health.Kind() != automerge.KindMap { - return nil, errors.New("health should be a map") - } - - healthMap := health.Map() - return healthMap, nil -} - -// DecrementHealth: indicates that the current node has voted that the health is down -func (m *CrdtMeshManager) DecrementHealth(nodeId string, selfId string) error { - healthMap, err := m.getHealthMap(nodeId) +// Prune: deletes every node whose timestamp is at least pruneTime seconds old. +// Returns an error if a lookup fails or the nodes entry, a node or a timestamp has the wrong kind +func (m *CrdtMeshManager) Prune(pruneTime int) error { + nodes, err := m.doc.Path("nodes").Get() if err != nil { return err } - err = healthMap.Set(selfId, struct{}{}) - - if err != nil { - logging.Log.WriteErrorf(err.Error()) + if nodes.Kind() != automerge.KindMap { + return errors.New("node must be a map") } - return nil -} - -// IncrementHealth: indicates that the current node thinks that the noden is up -func (m *CrdtMeshManager) IncrementHealth(nodeId string, selfId string) error { - healthMap, err := m.getHealthMap(nodeId) + values, err := nodes.Map().Values() if err != nil { return err } - err = healthMap.Delete(selfId) + deletionNodes := make([]string, 0) - if err != nil { - logging.Log.WriteErrorf(err.Error()) + for nodeId, node := range values { + if node.Kind() != 
automerge.KindMap { + return errors.New("node must be a map") + } + + nodeMap := node.Map() + + timeStamp, err := nodeMap.Get("timestamp") + + if err != nil { + return err + } + + if timeStamp.Kind() != automerge.KindInt64 { + return errors.New("timestamp is not int64") + } + + timeValue := timeStamp.Int64() + nowValue := time.Now().Unix() + + if nowValue-timeValue >= int64(pruneTime) { + deletionNodes = append(deletionNodes, nodeId) + } + } + + for _, node := range deletionNodes { + logging.Log.WriteInfof("Pruning %s", node) + // Deletion is best-effort: log a failure and keep pruning the remaining nodes + if err := nodes.Map().Delete(node); err != nil { + logging.Log.WriteErrorf("Pruning %s: %s", node, err.Error()) + } } return nil @@ -322,10 +316,6 @@ func (m *MeshNodeCrdt) GetIdentifier() string { return strings.Join(constituents, ":") } -func (m *MeshNodeCrdt) GetHealth() int { - return len(m.Health) -} - func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { nodes := make(map[string]mesh.MeshNode) diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index f4216e9..60315fe 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -9,7 +9,6 @@ type MeshNodeCrdt struct { Timestamp int64 `automerge:"timestamp"` Routes map[string]interface{} `automerge:"routes"` Description string `automerge:"description"` - Health map[string]interface{} `automerge:"health"` } // MeshCrdt: Represents the mesh network as a whole diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 146d479..26e1dfb 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -43,10 +43,12 @@ type WgMeshConfiguration struct { BranchRate int `yaml:"branchRate"` // InfectionCount number of times we sync before we can no longer catch the udpate InfectionCount int `yaml:"infectionCount"` - // KeepAliveTime + // KeepAliveTime number of seconds before we update node indicating that we are still alive KeepAliveTime int `yaml:"keepAliveTime"` - // Timeout number of seconds before we update node indicating that we are still alive + // Timeout number of seconds before we consider the node as dead Timeout int `yaml:"timeout"` + // PruneTime number of seconds 
before we consider the 'node' as dead + PruneTime int `yaml:"pruneTime"` } func ValidateConfiguration(c *WgMeshConfiguration) error { @@ -110,9 +112,29 @@ func ValidateConfiguration(c *WgMeshConfiguration) error { } } - if c.Timeout <= 1 { + if c.Timeout < 1 { return &WgMeshConfigurationError{ - msg: "Timeout should be less than or equal to 1", + msg: "Timeout should be greater than or equal to 1", + } + } + + if c.PruneTime <= 1 { + return &WgMeshConfigurationError{ + msg: "Prune time cannot be <= 1", + } + } + + if c.KeepAliveTime <= 1 { + return &WgMeshConfigurationError{ + msg: "Keep alive time cannot be <= 1", + } + } + + // Nodes refresh their timestamp every KeepAliveTime seconds, so a shorter + // PruneTime would prune nodes that are still alive + if c.PruneTime < c.KeepAliveTime { + return &WgMeshConfigurationError{ + msg: "Prune time cannot be less than keep alive time", } } diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index ba7b665..f6436bf 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -15,8 +15,10 @@ func getExampleConfiguration() *WgMeshConfiguration { SyncRate: 1, InterClusterChance: 0.1, BranchRate: 2, - KeepAliveTime: 1, + KeepAliveTime: 4, InfectionCount: 1, + Timeout: 2, + PruneTime: 20, } } @@ -128,3 +130,36 @@ func TestValidCOnfiguration(t *testing.T) { t.Error(err) } } + +func TestTimeout(t *testing.T) { + conf := getExampleConfiguration() + conf.Timeout = 0 + + err := ValidateConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestPruneTimeZero(t *testing.T) { + conf := getExampleConfiguration() + conf.PruneTime = 0 + + err := ValidateConfiguration(conf) + + if err == nil { + t.Fatalf(`Error should be thrown`) + } +} + +func TestPruneTimeLessThanKeepAliveTime(t *testing.T) { + conf := getExampleConfiguration() + conf.PruneTime = 2 + + err := ValidateConfiguration(conf) + + if err == nil { + t.Fatalf(`Error should be thrown`) + } +} diff --git a/pkg/lib/timer.go b/pkg/lib/timer.go new file mode 100644 index 0000000..fc0d9f4 --- /dev/null +++ b/pkg/lib/timer.go @@ -0,0 +1,51 @@ +package lib + +import "time" + +// TimerFunc is the callback a Timer invokes on every tick. +type TimerFunc = func() error + +// Timer calls f once every updateRate seconds until it is stopped. +type Timer struct { + f TimerFunc + quit chan struct{} + updateRate int +} + +// Run blocks, calling f on every tick. It returns the first error f produces, +// or nil once Stop closes the quit channel. +func 
(t *Timer) Run() error { + ticker := time.NewTicker(time.Duration(t.updateRate) * time.Second) + defer ticker.Stop() + + t.quit = make(chan struct{}) + + for { + select { + case <-ticker.C: + err := t.f() + + if err != nil { + return err + } + case <-t.quit: + // A bare break would only leave the select, so the loop would spin forever + return nil + } + } +} + +// Stop ends a running timer by closing the quit channel. The channel is +// created by Run, so Stop must only be called after Run has started. +func (t *Timer) Stop() error { + close(t.quit) + return nil +} + +// NewTimer returns a Timer that calls f every updateRate seconds once Run is called. +func NewTimer(f TimerFunc, updateRate int) *Timer { + return &Timer{ + f: f, + updateRate: updateRate, + } +} diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 5b85f6c..ed3698f 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -28,6 +28,7 @@ type MeshManager interface { UpdateTimeStamp() error GetClient() *wgctrl.Client GetMeshes() map[string]MeshProvider + Prune() error } type MeshManagerImpl struct { @@ -46,6 +47,19 @@ type MeshManagerImpl struct { interfaceManipulator wg.WgInterfaceManipulator } +// Prune implements MeshManager. +func (m *MeshManagerImpl) Prune() error { + for _, mesh := range m.Meshes { + err := mesh.Prune(m.conf.PruneTime) + + if err != nil { + return err + } + } + + return nil +} + // CreateMesh: Creates a new mesh, stores it and returns the mesh id func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { meshId, err := m.idGenerator.GetId() @@ -216,7 +230,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { _, exists := s.Meshes[meshId] if !exists { - return errors.New(fmt.Sprintf("mesh %s does not exist", meshId)) + return fmt.Errorf("mesh %s does not exist", meshId) } // For now just delete the mesh with the ID. 
@@ -228,7 +242,7 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { meshInstance, ok := s.Meshes[meshId] if !ok { - return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId)) + return nil, fmt.Errorf("mesh %s does not exist", meshId) } snapshot, err := meshInstance.GetMesh() diff --git a/pkg/mesh/pruner.go b/pkg/mesh/pruner.go new file mode 100644 index 0000000..904bf22 --- /dev/null +++ b/pkg/mesh/pruner.go @@ -0,0 +1,16 @@ +package mesh + +import ( + "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" +) + +func pruneFunction(m MeshManager) lib.TimerFunc { + return func() error { + return m.Prune() + } +} + +func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer { + return lib.NewTimer(pruneFunction(m), conf.PruneTime/2) +} diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 9605d88..aa3de11 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -21,11 +21,6 @@ type MeshNodeStub struct { description string } -// GetHealth implements MeshNode. -func (*MeshNodeStub) GetHealth() int { - return 5 -} - func (m *MeshNodeStub) GetHostEndpoint() string { return m.hostEndpoint } @@ -71,13 +66,8 @@ type MeshProviderStub struct { snapshot *MeshSnapshotStub } -// DecrementHealth implements MeshProvider. -func (*MeshProviderStub) DecrementHealth(nodeId string, selfId string) error { - return nil -} - -// IncrementHealth implements MeshProvider. -func (*MeshProviderStub) IncrementHealth(nodeId string, selfId string) error { +// Prune implements MeshProvider. +func (*MeshProviderStub) Prune(pruneAmount int) error { return nil } @@ -169,10 +159,18 @@ func (a *MeshConfigApplyerStub) RemovePeers(meshId string) error { return nil } +func (a *MeshConfigApplyerStub) SetMeshManager(manager MeshManager) { +} + type MeshManagerStub struct { meshes map[string]MeshProvider } +// Prune implements MeshManager. 
+func (*MeshManagerStub) Prune() error { + return nil +} + func NewMeshManagerStub() MeshManager { return &MeshManagerStub{meshes: make(map[string]MeshProvider)} } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index ef6e1e5..7bffd88 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -28,8 +28,6 @@ type MeshNode interface { GetIdentifier() string // GetDescription: returns the description for this node GetDescription() string - // GetHealth: returns the health score for this mesh node - GetHealth() int } type MeshSnapshot interface { @@ -70,12 +68,9 @@ type MeshProvider interface { GetSyncer() MeshSyncer // SetDescription: sets the description of this automerge data type SetDescription(nodeId string, description string) error - // DecrementHealth: indicates that the node with selfId thinks that the node - // is down - DecrementHealth(nodeId string, selfId string) error - // IncrementHealth: indicates that the node is up and so increment the health of the - // node - IncrementHealth(nodeId string, selfId string) error + // Prune: prunes all nodes that have not updated their timestamp in + // pruneAmount seconds + Prune(pruneAmount int) error } // HostParameters contains the IDs of a node diff --git a/pkg/query/query.go b/pkg/query/query.go index 95372b3..0978f08 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -31,7 +31,6 @@ type QueryNode struct { Timestamp int64 `json:"timestmap"` Description string `json:"description"` Routes []string `json:"routes"` - Health int `json:"health"` } func (m *QueryError) Error() string { @@ -77,7 +76,6 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { queryNode.Timestamp = node.GetTimeStamp() queryNode.Routes = node.GetRoutes() queryNode.Description = node.GetDescription() - queryNode.Health = node.GetHealth() return queryNode } diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go index 7a5b529..4458c9c 100644 --- a/pkg/sync/syncererror.go +++ b/pkg/sync/syncererror.go @@ -24,13 +24,13 @@ 
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri return false } - self, err := s.meshManager.GetSelf(meshId) + // self, err := s.meshManager.GetSelf(meshId) - if err != nil { - return false - } + // if err != nil { + // return false + // } - mesh.DecrementHealth(meshId, self.GetHostEndpoint()) + // mesh.DecrementHealth(endpoint, self.GetHostEndpoint()) return true } diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index 221544e..3c40d72 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -100,18 +100,6 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { return s.handleErr(meshId, endpoint, err) } - self, err := s.server.MeshManager.GetSelf(mesh.GetMeshId()) - - if err != nil { - return err - } - - err = mesh.IncrementHealth(meshId, self.GetHostEndpoint()) - - if err != nil { - return err - } - logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) return nil } diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 8d0856e..ddec396 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -1,10 +1,8 @@ package sync import ( - "time" - "github.com/tim-beatham/wgmesh/pkg/ctrlserver" - logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/lib" ) // SyncScheduler: Loops through all nodes in the mesh and runs a schedule to @@ -22,34 +20,13 @@ type SyncSchedulerImpl struct { } // Run implements SyncScheduler. -func (s *SyncSchedulerImpl) Run() error { - ticker := time.NewTicker(time.Duration(s.server.Conf.SyncRate) * time.Second) - - quit := make(chan struct{}) - s.quit = quit - - for { - select { - case <-ticker.C: - err := s.syncer.SyncMeshes() - - if err != nil { - logging.Log.WriteErrorf(err.Error()) - } - break - case <-quit: - break - } +func syncFunction(syncer Syncer) lib.TimerFunc { + return func() error { + return syncer.SyncMeshes() } } -// Stop implements SyncScheduler. 
-func (s *SyncSchedulerImpl) Stop() error { - close(s.quit) - return nil -} - -func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) SyncScheduler { +func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer { syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester) - return &SyncSchedulerImpl{server: s, syncer: syncer} + return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate)) } diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go index 0ca38e9..4c41289 100644 --- a/pkg/timestamp/timestamp.go +++ b/pkg/timestamp/timestamp.go @@ -1,51 +1,14 @@ package timestamp import ( - "time" - "github.com/tim-beatham/wgmesh/pkg/ctrlserver" - logging "github.com/tim-beatham/wgmesh/pkg/log" - "github.com/tim-beatham/wgmesh/pkg/mesh" + "github.com/tim-beatham/wgmesh/pkg/lib" ) -type TimestampScheduler interface { - Run() error - Stop() error -} - -type TimeStampSchedulerImpl struct { - meshManager mesh.MeshManager - updateRate int - quit chan struct{} -} - -func (s *TimeStampSchedulerImpl) Run() error { - ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second) - - s.quit = make(chan struct{}) - - for { - select { - case <-ticker.C: - err := s.meshManager.UpdateTimeStamp() - - if err != nil { - logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error()) - } - case <-s.quit: - break - } +func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { + timerFunc := func() error { + return ctrlServer.MeshManager.UpdateTimeStamp() } -} -func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) TimestampScheduler { - return &TimeStampSchedulerImpl{ - meshManager: ctrlServer.MeshManager, - updateRate: ctrlServer.Conf.KeepAliveTime, - } -} - -func (s *TimeStampSchedulerImpl) Stop() error { - close(s.quit) - return nil + return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) }