From 2e6aed6f93e35e486cf8575c4e622b56a5cab199 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Sat, 30 Dec 2023 00:44:57 +0000 Subject: [PATCH] main - Fixing issue with nil pointer de-reference due to bad design of mesh manager. - Going forward all references to GetSelf should be depracated. It introduces a race condition when leaving a mesh network --- pkg/mesh/config.go | 37 ++++++++++++++-------- pkg/mesh/route.go | 19 +++++++++--- pkg/mesh/types.go | 4 +++ pkg/sync/syncer.go | 77 +++++++++++++++++++++++----------------------- 4 files changed, 80 insertions(+), 57 deletions(-) diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 75b8d77..9d37211 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -119,7 +119,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg // getRoutes: finds the routes with the least hop distance. If more than one route exists // consistently hash to evenly spread the distribution of traffic -func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode { +func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) { mesh, _ := meshProvider.GetMesh() routes := make(map[string][]routeNode) @@ -158,17 +158,19 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] // Client's only acessible by another peer if node.GetType() == conf.CLIENT_ROLE { peer := m.getCorrespondingPeer(peers, node) - self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId()) + self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String()) + + if err != nil { + return nil, err + } - // If the node isn't the self use that peer as the gateway if !NodeEquals(peer, self) { peerPub, _ := peer.GetPublicKey() rn.gateway = peerPub.String() rn.route = &RouteStub{ Destination: rn.route.GetDestination(), HopCount: rn.route.GetHopCount() + 1, - // Append the path to this peer - Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()), + Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()), } } } @@ -185,7 +187,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] } } - return routes + return routes, nil } // getCorrespondignPeer: gets the peer corresponding to the client @@ -219,7 +221,6 @@ type GetConfigParams struct { } func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) { - self, err := m.meshManager.GetSelf(params.mesh.GetMeshId()) ula := &ip.ULABuilder{} meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId()) @@ -235,6 +236,8 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes }) routes = append(routes, *meshNet) + self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String()) + if err != nil { return nil, err } @@ -302,7 +305,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C peerToClients := make(map[string][]net.IPNet) installedRoutes := make([]lib.Route, 0) peerConfigs := make([]wgtypes.PeerConfig, 0) - self, err := m.meshManager.GetSelf(params.mesh.GetMeshId()) + self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String()) if err != nil { return nil, err @@ -393,7 +396,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string] return mn.GetType() == conf.CLIENT_ROLE }) - self, err := m.meshManager.GetSelf(mesh.GetMeshId()) + self, err := mesh.GetNode(m.meshManager.GetPublicKey().String()) if err != nil { return err @@ -432,11 +435,15 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string] return nil } -func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode { +func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) { allRoutes := make(map[string][]routeNode) for _, mesh := range m.meshManager.GetMeshes() { - routes := m.getRoutes(mesh) + routes, err := m.getRoutes(mesh) + + if err != nil { + return nil, err + } for destination, route := range routes { _, ok := allRoutes[destination] @@ -454,11 +461,15 @@ func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode { } } - return allRoutes + return allRoutes, nil } func (m *WgMeshConfigApplyer) ApplyConfig() error { - allRoutes := m.getAllRoutes() + allRoutes, err := m.getAllRoutes() + + if err != nil { + return err + } for _, mesh := range m.meshManager.GetMeshes() { err := m.updateWgConf(mesh, allRoutes) diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index f7fd4c4..e33a349 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -24,7 +24,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error { continue } - self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) + self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String()) if err != nil { return err @@ -90,11 +90,20 @@ func (r *RouteManagerImpl) UpdateRoutes() error { // Calculate the set different of each, working out routes to remove and to keep. for meshId, meshRoutes := range routes { - mesh := r.meshManager.GetMesh(meshId) - self, _ := r.meshManager.GetSelf(meshId) - toRemove := make([]Route, 0) + mesh := meshes[meshId] - prevRoutes, _ := mesh.GetRoutes(NodeID(self)) + self, err := mesh.GetNode(r.meshManager.GetPublicKey().String()) + + if err != nil { + return err + } + + toRemove := make([]Route, 0) + prevRoutes, err := mesh.GetRoutes(NodeID(self)) + + if err != nil { + return err + } for _, route := range prevRoutes { if !lib.Contains(meshRoutes, func(r Route) bool { diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index bbe9afc..9f2417b 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -81,6 +81,10 @@ func NodeEquals(node1, node2 MeshNode) bool { key1, _ := node1.GetPublicKey() key2, _ := node2.GetPublicKey() + if node1 == nil || node2 == nil { + return false + } + return key1.String() == key2.String() } diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 896b33f..71704c5 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -1,7 +1,6 @@ package sync import ( - "errors" "fmt" "io" "math/rand" @@ -16,7 +15,7 @@ import ( // Syncer: picks random nodes from the meshs type Syncer interface { - Sync(meshId string) error + Sync(theMesh mesh.MeshProvider) error SyncMeshes() error } @@ -30,21 +29,33 @@ type SyncerImpl struct { lastSync map[string]uint64 } -// Sync: Sync random nodes -func (s *SyncerImpl) Sync(meshId string) error { - // Self can be nil if the node is removed - self, _ := s.manager.GetSelf(meshId) +// Sync: Sync with random nodes +func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error { + if correspondingMesh == nil { + return fmt.Errorf("mesh provided was nil cannot sync nil mesh") + } - correspondingMesh := s.manager.GetMesh(meshId) + // Self can be nil if the node is removed + selfID := s.manager.GetPublicKey() + self, _ := correspondingMesh.GetNode(selfID.String()) + + // Mesh has been removed + if self == nil { + return fmt.Errorf("mesh %s does not exist", correspondingMesh.GetMeshId()) + } correspondingMesh.Prune() - if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { - logging.Log.WriteInfof("No changes for %s", meshId) + if correspondingMesh.HasChanges() { + logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId()) + } + + if self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 { + logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId()) // If not synchronised in certain pull from random neighbour - if uint64(time.Now().Unix())-s.lastSync[meshId] > 20 { - return s.Pull(meshId) + if uint64(time.Now().Unix())-s.lastSync[correspondingMesh.GetMeshId()] > 20 { + return s.Pull(self, correspondingMesh) } return nil @@ -87,14 +98,14 @@ func (s *SyncerImpl) Sync(meshId string) error { // Do this synchronously to conserve bandwidth for _, node := range gossipNodes { - correspondingPeer := s.manager.GetNode(meshId, node) + correspondingPeer := s.manager.GetNode(correspondingMesh.GetMeshId(), node) if correspondingPeer == nil { logging.Log.WriteErrorf("node %s does not exist", node) continue } - err := s.requester.SyncMesh(meshId, correspondingPeer) + err := s.requester.SyncMesh(correspondingMesh.GetMeshId(), correspondingPeer) if err == nil || err == io.EOF { succeeded = true @@ -116,36 +127,18 @@ func (s *SyncerImpl) Sync(meshId string) error { s.infectionCount++ } - s.manager.GetMesh(meshId).SaveChanges() - s.lastSync[meshId] = uint64(time.Now().Unix()) - - logging.Log.WriteInfof("UPDATING WG CONF") - err := s.manager.ApplyConfig() - - if err != nil { - logging.Log.WriteInfof("Failed to update config %w", err) - } + correspondingMesh.SaveChanges() + s.lastSync[correspondingMesh.GetMeshId()] = uint64(time.Now().Unix()) return nil } // Pull one node in the cluster, if there has not been message dissemination // in a certain period of time pull a random node within the cluster -func (s *SyncerImpl) Pull(meshId string) error { - mesh := s.manager.GetMesh(meshId) - self, err := s.manager.GetSelf(meshId) - - if err != nil { - return err - } - +func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) error { + peers := mesh.GetPeers() pubKey, _ := self.GetPublicKey() - if mesh == nil { - return errors.New("mesh is nil, invalid operation") - } - - peers := mesh.GetPeers() neighbours := s.cluster.GetNeighbours(peers, pubKey.String()) neighbour := lib.RandomSubsetOfLength(neighbours, 1) @@ -162,10 +155,10 @@ func (s *SyncerImpl) Pull(meshId string) error { return fmt.Errorf("node %s does not exist in the mesh", neighbour[0]) } - err = s.requester.SyncMesh(meshId, pullNode) + err = s.requester.SyncMesh(mesh.GetMeshId(), pullNode) if err == nil || err == io.EOF { - s.lastSync[meshId] = uint64(time.Now().Unix()) + s.lastSync[mesh.GetMeshId()] = uint64(time.Now().Unix()) } else { return err } @@ -176,14 +169,20 @@ func (s *SyncerImpl) Pull(meshId string) error { // SyncMeshes: Sync all meshes func (s *SyncerImpl) SyncMeshes() error { - for meshId := range s.manager.GetMeshes() { - err := s.Sync(meshId) + for _, mesh := range s.manager.GetMeshes() { + err := s.Sync(mesh) if err != nil { logging.Log.WriteErrorf(err.Error()) } } + logging.Log.WriteInfof("updating the WireGuard configuration") + err := s.manager.ApplyConfig() + + if err != nil { + logging.Log.WriteInfof("failed to update config %w", err) + } return nil }