From 32e7e4c7dfd23bddc3e9eb87d4513ee5cf898330 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Tue, 28 Nov 2023 14:42:09 +0000 Subject: [PATCH] main Bugfix. Fixed issue where consistent hashing was not working. --- cmd/wgmeshd/main.go | 12 +++-- pkg/automerge/automerge.go | 91 ++++++++++++++++++++++++++++++++- pkg/automerge/automerge_sync.go | 1 - pkg/ctrlserver/ctrlserver.go | 2 + pkg/mesh/config.go | 34 +++++------- pkg/mesh/manager.go | 71 ++++++++++++++++--------- pkg/mesh/stub_types.go | 9 +++- pkg/mesh/types.go | 2 + pkg/sync/syncer.go | 28 ++++------ pkg/sync/syncscheduler.go | 3 +- pkg/timers/timers.go | 8 --- 11 files changed, 180 insertions(+), 81 deletions(-) diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index aeab44c..0319f52 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -45,21 +45,26 @@ func main() { var robinRpc robin.WgRpc var robinIpc robin.IpcHandler var syncProvider sync.SyncServiceImpl + var syncRequester sync.SyncRequester + var syncer sync.Syncer ctrlServerParams := ctrlserver.NewCtrlServerParams{ Conf: conf, CtrlProvider: &robinRpc, SyncProvider: &syncProvider, Client: client, + OnDelete: func(mp mesh.MeshProvider) { + syncer.SyncMeshes() + }, } ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) syncProvider.Server = ctrlServer - syncRequester := sync.NewSyncRequester(ctrlServer) - syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) + syncRequester = sync.NewSyncRequester(ctrlServer) + syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) + syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer) timestampScheduler := timer.NewTimestampScheduler(ctrlServer) pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) - routeScheduler := timer.NewRouteScheduler(ctrlServer) robinIpcParams := robin.RobinIpcParams{ CtrlServer: ctrlServer, @@ -79,7 +84,6 @@ func main() { go syncScheduler.Run() go timestampScheduler.Run() go pruneScheduler.Run() - go routeScheduler.Run() closeResources := func() { logging.Log.WriteInfof("Closing resources") diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index e57a7ed..c77debc 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -4,7 +4,9 @@ import ( "errors" "fmt" "net" + "slices" "strings" + "sync" "time" "github.com/automerge/automerge-go" @@ -18,6 +20,7 @@ import ( // CrdtMeshManager manages nodes in the crdt mesh type CrdtMeshManager struct { + lock sync.RWMutex MeshId string IfName string Client *wgctrl.Client @@ -39,10 +42,13 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { crdt.Services = make(map[string]string) crdt.Timestamp = time.Now().Unix() + c.lock.Lock() c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) + c.lock.Unlock() } func (c *CrdtMeshManager) isPeer(nodeId string) bool { + c.lock.RLock() node, err := c.doc.Path("nodes").Map().Get(nodeId) if err != nil || node.Kind() != automerge.KindMap { @@ -50,6 +56,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool { } nodeType, err := node.Map().Get("type") + c.lock.RUnlock() if err != nil || nodeType.Kind() != automerge.KindStr { return false @@ -61,6 +68,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool { // isAlive: checks that the node's configuration has been updated // since the rquired keep alive time func (c *CrdtMeshManager) isAlive(nodeId string) bool { + c.lock.RLock() node, err := c.doc.Path("nodes").Map().Get(nodeId) if err != nil || node.Kind() != automerge.KindMap { @@ -68,6 +76,7 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool { } timestamp, err := node.Map().Get("timestamp") + c.lock.RUnlock() if err != nil || timestamp.Kind() != automerge.KindInt64 { return false @@ -78,7 +87,9 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool { } func (c *CrdtMeshManager) GetPeers() []string { + c.lock.RLock() keys, _ := c.doc.Path("nodes").Map().Keys() + c.lock.RUnlock() keys = lib.Filter(keys, func(publicKey string) bool { return c.isPeer(publicKey) && c.isAlive(publicKey) @@ -97,7 +108,9 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { if c.cache == nil || len(changes) > 0 { c.lastCacheHash = c.LastHash + c.lock.RLock() cache, err := automerge.As[*MeshCrdt](c.doc.Root()) + c.lock.RUnlock() if err != nil { return nil, err @@ -157,6 +170,7 @@ func (m *CrdtMeshManager) NodeExists(key string) bool { } func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { + m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(endpoint) if node.Kind() != automerge.KindMap { @@ -168,6 +182,7 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { } meshNode, err := automerge.As[*MeshNodeCrdt](node) + m.lock.RUnlock() if err != nil { return nil, err @@ -213,7 +228,9 @@ func (m *CrdtMeshManager) SaveChanges() { } func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { + m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) + m.lock.RUnlock() if err != nil { return err @@ -223,7 +240,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { return errors.New("node is not a map") } + m.lock.Lock() err = node.Map().Set("timestamp", time.Now().Unix()) + m.lock.Unlock() if err == nil { logging.Log.WriteInfof("Timestamp Updated for %s", nodeId) @@ -233,7 +252,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { } func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error { + m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) + m.lock.RUnlock() if err != nil { return err @@ -243,7 +264,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro return fmt.Errorf("%s does not exist", nodeId) } + m.lock.Lock() err = node.Map().Set("description", description) + m.lock.Unlock() if err == nil { logging.Log.WriteInfof("Description Updated for %s", nodeId) @@ -253,7 +276,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro } func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { + m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) + m.lock.RUnlock() if err != nil { return err @@ -263,7 +288,9 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { return fmt.Errorf("%s does not exist", nodeId) } + m.lock.Lock() err = node.Map().Set("alias", alias) + m.lock.Unlock() if err == nil { logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias) @@ -273,13 +300,17 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { } func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { + m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) + m.lock.RUnlock() if err != nil || node.Kind() != automerge.KindMap { return fmt.Errorf("AddService: node %s does not exist", nodeId) } + m.lock.RLock() service, err := node.Map().Get("services") + m.lock.RUnlock() if err != nil { return err @@ -289,10 +320,14 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { return fmt.Errorf("AddService: services property does not exist in node") } - return service.Map().Set(key, value) + m.lock.Lock() + err = service.Map().Set(key, value) + m.lock.Unlock() + return err } func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { + m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil || node.Kind() != automerge.KindMap { @@ -308,8 +343,11 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { if service.Kind() != automerge.KindMap { return fmt.Errorf("services property does not exist") } + m.lock.RUnlock() + m.lock.Lock() err = service.Map().Delete(key) + m.lock.Unlock() if err != nil { return fmt.Errorf("service %s does not exist", key) @@ -320,6 +358,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { // AddRoutes: adds routes to the specific nodeId func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { + m.lock.RLock() nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId) @@ -332,16 +371,41 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { } routeMap, err := nodeVal.Map().Get("routes") + m.lock.RUnlock() if err != nil { return err } for _, route := range routes { + prevRoute, err := routeMap.Map().Get(route.GetDestination().String()) + + if prevRoute.Kind() == automerge.KindVoid && err != nil { + path, err := prevRoute.Map().Get("path") + + if err != nil { + return err + } + + if path.Kind() != automerge.KindList { + return fmt.Errorf("path is not a list") + } + + pathStr, err := automerge.As[[]string](path) + + if err != nil { + return err + } + + slices.Equal(route.GetPath(), pathStr) + } + + m.lock.Lock() err = routeMap.Map().Set(route.GetDestination().String(), Route{ Destination: route.GetDestination().String(), Path: route.GetPath(), }) + m.lock.Unlock() if err != nil { return err @@ -351,6 +415,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { } func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { + m.lock.RLock() nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil { @@ -372,6 +437,7 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { } routes, err := automerge.As[map[string]Route](routeMap) + m.lock.RUnlock() return lib.MapValues(routes), err } @@ -385,10 +451,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e routes := make(map[string]mesh.Route) + // Add routes that the node directly has for _, route := range node.GetRoutes() { routes[route.GetDestination().String()] = route } + // Work out the other routes in the mesh for _, node := range m.GetPeers() { nodeRoutes, err := m.getRoutes(node) @@ -399,6 +467,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e for _, route := range nodeRoutes { otherRoute, ok := routes[route.GetDestination().String()] + hopCount := route.GetHopCount() + + if node != targetNode { + hopCount += 1 + } + if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() { routes[route.GetDestination().String()] = &Route{ Destination: route.GetDestination().String(), @@ -411,8 +485,16 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e return routes, nil } +func (m *CrdtMeshManager) RemoveNode(nodeId string) error { + m.lock.Lock() + err := m.doc.Path("nodes").Map().Delete(nodeId) + m.lock.Unlock() + return err +} + // DeleteRoutes deletes the specified routes func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { + m.lock.RLock() nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil { @@ -424,14 +506,17 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { } routeMap, err := nodeVal.Map().Get("routes") + m.lock.RUnlock() if err != nil { return err } + m.lock.Lock() for _, route := range routes { err = routeMap.Map().Delete(route) } + m.lock.Unlock() return err } @@ -441,6 +526,7 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { } func (m *CrdtMeshManager) Prune(pruneTime int) error { + m.lock.RLock() nodes, err := m.doc.Path("nodes").Get() if err != nil { @@ -452,6 +538,7 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error { } values, err := nodes.Map().Values() + m.lock.RUnlock() if err != nil { return err @@ -466,7 +553,9 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error { nodeMap := node.Map() + m.lock.RLock() timeStamp, err := nodeMap.Get("timestamp") + m.lock.RUnlock() if err != nil { return err diff --git a/pkg/automerge/automerge_sync.go b/pkg/automerge/automerge_sync.go index 1c6de90..86ca53c 100644 --- a/pkg/automerge/automerge_sync.go +++ b/pkg/automerge/automerge_sync.go @@ -32,7 +32,6 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error { func (a *AutomergeSync) Complete() { logging.Log.WriteInfof("Sync Completed") - a.manager.SaveChanges() } func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 6708956..92de0e4 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -21,6 +21,7 @@ type NewCtrlServerParams struct { CtrlProvider rpc.MeshCtrlServerServer SyncProvider rpc.SyncServiceServer Querier query.Querier + OnDelete func(mesh.MeshProvider) } // Create a new instance of the MeshCtrlServer or error if the @@ -46,6 +47,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { IPAllocator: ipAllocator, InterfaceManipulator: interfaceManipulator, ConfigApplyer: configApplyer, + OnDelete: params.OnDelete, } ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 8cda57a..ff51311 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -9,6 +9,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" + logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/route" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -32,10 +33,6 @@ type routeNode struct { route Route } -func (r *routeNode) equals(route2 *routeNode) bool { - return r.gateway == route2.gateway && RouteEquals(r.route, route2.route) -} - func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device, peerToClients map[string][]net.IPNet, routes map[string][]routeNode) (*wgtypes.PeerConfig, error) { @@ -63,9 +60,10 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev for _, route := range node.GetRoutes() { bestRoutes := routes[route.GetDestination().String()] + var pickedRoute routeNode if len(bestRoutes) == 1 { - allowedips = append(allowedips, *route.GetDestination()) + pickedRoute = bestRoutes[0] } else if len(bestRoutes) > 1 { keyFunc := func(mn MeshNode) int { pubKey, _ := mn.GetPublicKey() @@ -77,11 +75,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev } // Else there is more than one candidate so consistently hash - pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc) + pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc) + } - if pickedRoute.gateway == pubKey.String() { - allowedips = append(allowedips, *route.GetDestination()) - } + if pickedRoute.gateway == pubKey.String() { + allowedips = append(allowedips, *pickedRoute.route.GetDestination()) } } @@ -101,6 +99,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev Endpoint: endpoint, AllowedIPs: allowedips, PersistentKeepaliveInterval: &keepAlive, + ReplaceAllowedIPs: true, } return &peerConfig, nil @@ -122,14 +121,9 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] for _, node := range mesh.GetNodes() { pubKey, _ := node.GetPublicKey() - meshRoutes, _ := meshProvider.GetRoutes(pubKey.String()) - for _, route := range meshRoutes { + for _, route := range node.GetRoutes() { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { - if prefix == nil || route == nil || route.GetDestination() == nil { - return false - } - return prefix.Contains(route.GetDestination().IP) }) { continue @@ -150,6 +144,8 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] } else if route.GetHopCount() < otherRoute[0].route.GetHopCount() { otherRoute[0] = rn } else if otherRoute[0].route.GetHopCount() == route.GetHopCount() { + logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount()) + logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount()) routes[destination] = append(otherRoute, rn) } } @@ -218,7 +214,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) if !ipNet.Contains(route.IP) { - installedRoutes = append(installedRoutes, lib.Route{ Gateway: n.GetWgHost().IP, Destination: route, @@ -240,13 +235,13 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return err } - err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) + err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) if err != nil { return err } - return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) + return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) } func (m *WgMeshConfigApplyer) ApplyConfig() error { @@ -275,8 +270,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error { } m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ - ReplacePeers: true, - Peers: make([]wgtypes.PeerConfig, 0), + Peers: make([]wgtypes.PeerConfig, 0), }) return nil diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index e7aa346..e099bb0 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -3,6 +3,7 @@ package mesh import ( "errors" "fmt" + "sync" "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" @@ -18,7 +19,7 @@ type MeshManager interface { AddMesh(params *AddMeshParams) error HasChanges(meshid string) bool GetMesh(meshId string) MeshProvider - GetPublicKey(meshId string) (*wgtypes.Key, error) + GetPublicKey() *wgtypes.Key AddSelf(params *AddSelfParams) error LeaveMesh(meshId string) error GetSelf(meshId string) (MeshNode, error) @@ -38,6 +39,7 @@ type MeshManager interface { } type MeshManagerImpl struct { + lock sync.RWMutex Meshes map[string]MeshProvider RouteManager RouteManager Client *wgctrl.Client @@ -52,6 +54,7 @@ type MeshManagerImpl struct { ipAllocator ip.IPAllocator interfaceManipulator wg.WgInterfaceManipulator Monitor MeshMonitor + OnDelete func(MeshProvider) } // GetRouteManager implements MeshManager. @@ -149,7 +152,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { return "", fmt.Errorf("error creating mesh: %w", err) } + m.lock.Lock() m.Meshes[meshId] = nodeManager + m.lock.Unlock() return meshId, nil } @@ -190,7 +195,9 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { return err } + m.lock.Lock() m.Meshes[params.MeshId] = meshProvider + m.lock.Unlock() return nil } @@ -206,25 +213,14 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider { } // GetPublicKey: Gets the public key of the WireGuard mesh -func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { +func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key { if s.conf.StubWg { zeroedKey := make([]byte, wgtypes.KeyLen) - return (*wgtypes.Key)(zeroedKey), nil + return (*wgtypes.Key)(zeroedKey) } - mesh, ok := s.Meshes[meshId] - - if !ok { - return nil, errors.New("mesh does not exist") - } - - dev, err := mesh.GetDevice() - - if err != nil { - return nil, err - } - - return &dev.PublicKey, nil + key := s.HostParameters.PrivateKey.PublicKey() + return &key } type AddSelfParams struct { @@ -289,14 +285,29 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { // LeaveMesh leaves the mesh network func (s *MeshManagerImpl) LeaveMesh(meshId string) error { - mesh, exists := s.Meshes[meshId] + mesh := s.GetMesh(meshId) - if !exists { + if mesh == nil { return fmt.Errorf("mesh %s does not exist", meshId) } var err error + s.RouteManager.RemoveRoutes(meshId) + err = mesh.RemoveNode(s.HostParameters.GetPublicKey()) + + if err != nil { + return err + } + + if s.OnDelete != nil { + s.OnDelete(mesh) + } + + s.lock.Lock() + delete(s.Meshes, meshId) + s.lock.Unlock() + if !s.conf.StubWg { device, err := mesh.GetDevice() @@ -311,8 +322,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { } } - err = s.RouteManager.RemoveRoutes(meshId) - delete(s.Meshes, meshId) return err } @@ -348,7 +357,8 @@ func (s *MeshManagerImpl) ApplyConfig() error { } func (s *MeshManagerImpl) SetDescription(description string) error { - for _, mesh := range s.Meshes { + meshes := s.GetMeshes() + for _, mesh := range meshes { if mesh.NodeExists(s.HostParameters.GetPublicKey()) { err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description) @@ -363,7 +373,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error { // SetAlias implements MeshManager. func (s *MeshManagerImpl) SetAlias(alias string) error { - for _, mesh := range s.Meshes { + meshes := s.GetMeshes() + for _, mesh := range meshes { if mesh.NodeExists(s.HostParameters.GetPublicKey()) { err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) @@ -377,7 +388,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error { // UpdateTimeStamp updates the timestamp of this node in all meshes func (s *MeshManagerImpl) UpdateTimeStamp() error { - for _, mesh := range s.Meshes { + meshes := s.GetMeshes() + for _, mesh := range meshes { if mesh.NodeExists(s.HostParameters.GetPublicKey()) { err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey()) @@ -395,7 +407,16 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client { } func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { - return s.Meshes + meshes := make(map[string]MeshProvider) + + s.lock.RLock() + + for id, mesh := range s.Meshes { + meshes[id] = mesh + } + + s.lock.RUnlock() + return meshes } // Close the mesh manager @@ -432,6 +453,7 @@ type NewMeshManagerParams struct { InterfaceManipulator wg.WgInterfaceManipulator ConfigApplyer MeshConfigApplyer RouteManager RouteManager + OnDelete func(MeshProvider) } // Creates a new instance of a mesh manager with the given parameters @@ -466,5 +488,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager { aliasManager := NewAliasManager() m.Monitor.AddUpdateCallback(aliasManager.AddAliases) m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases) + m.OnDelete = params.OnDelete return m } diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 53132ff..23e33eb 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -81,6 +81,11 @@ type MeshProviderStub struct { snapshot *MeshSnapshotStub } +// RemoveNode implements MeshProvider. +func (*MeshProviderStub) RemoveNode(nodeId string) error { + panic("unimplemented") +} + func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) { return nil, nil } @@ -287,9 +292,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider { snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}} } -func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) { +func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key { key, _ := wgtypes.GenerateKey() - return &key, nil + return &key } func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error { diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 8ab2acf..9f461fc 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -138,6 +138,8 @@ type MeshProvider interface { GetPeers() []string // GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen GetRoutes(targetNode string) (map[string]Route, error) + // RemoveNode(): remove the node from the mesh + RemoveNode(nodeId string) error } // HostParameters contains the IDs of a node diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index ca6fc48..a878694 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -36,28 +36,17 @@ func (s *SyncerImpl) Sync(meshId string) error { logging.Log.WriteInfof("UPDATING WG CONF") - if s.manager.HasChanges(meshId) { - err := s.manager.ApplyConfig() + s.manager.GetRouteManager().UpdateRoutes() + err := s.manager.ApplyConfig() - if err != nil { - logging.Log.WriteInfof("Failed to update config %w", err) - } + if err != nil { + logging.Log.WriteInfof("Failed to update config %w", err) } + publicKey := s.manager.GetPublicKey() + nodeNames := s.manager.GetMesh(meshId).GetPeers() - self, err := s.manager.GetSelf(meshId) - - if err != nil { - return err - } - - selfPublickey, err := self.GetPublicKey() - - if err != nil { - return err - } - - neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String()) + neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) for _, node := range randomSubset { @@ -68,7 +57,7 @@ func (s *SyncerImpl) Sync(meshId string) error { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { logging.Log.WriteInfof("Sending to random cluster") - interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String()) + interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String()) randomSubset = append(randomSubset, interCluster) } @@ -102,6 +91,7 @@ func (s *SyncerImpl) Sync(meshId string) error { // Check if any changes have occurred and trigger callbacks // if changes have occurred. // return s.manager.GetMonitor().Trigger() + s.manager.GetMesh(meshId).SaveChanges() return nil } diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 3315df8..61bdf41 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -12,7 +12,6 @@ func syncFunction(syncer Syncer) lib.TimerFunc { } } -func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer { - syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester) +func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer { return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate)) } diff --git a/pkg/timers/timers.go b/pkg/timers/timers.go index 0d9531c..84e1e7c 100644 --- a/pkg/timers/timers.go +++ b/pkg/timers/timers.go @@ -12,11 +12,3 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) } - -func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { - timerFunc := func() error { - return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes() - } - - return *lib.NewTimer(timerFunc, 10) -}