mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-02-03 11:09:17 +01:00
main
- Fixing issue with nil pointer de-reference due to bad design of mesh manager. - Going forward all references to GetSelf should be depracated. It introduces a race condition when leaving a mesh network
This commit is contained in:
parent
b0893a0b8e
commit
2e6aed6f93
@ -119,7 +119,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
|
||||
|
||||
// getRoutes: finds the routes with the least hop distance. If more than one route exists
|
||||
// consistently hash to evenly spread the distribution of traffic
|
||||
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
|
||||
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
|
||||
mesh, _ := meshProvider.GetMesh()
|
||||
routes := make(map[string][]routeNode)
|
||||
|
||||
@ -158,17 +158,19 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
// Client's only acessible by another peer
|
||||
if node.GetType() == conf.CLIENT_ROLE {
|
||||
peer := m.getCorrespondingPeer(peers, node)
|
||||
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
|
||||
self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the node isn't the self use that peer as the gateway
|
||||
if !NodeEquals(peer, self) {
|
||||
peerPub, _ := peer.GetPublicKey()
|
||||
rn.gateway = peerPub.String()
|
||||
rn.route = &RouteStub{
|
||||
Destination: rn.route.GetDestination(),
|
||||
HopCount: rn.route.GetHopCount() + 1,
|
||||
// Append the path to this peer
|
||||
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
|
||||
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -185,7 +187,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// getCorrespondignPeer: gets the peer corresponding to the client
|
||||
@ -219,7 +221,6 @@ type GetConfigParams struct {
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
|
||||
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
|
||||
ula := &ip.ULABuilder{}
|
||||
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
|
||||
|
||||
@ -235,6 +236,8 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
|
||||
})
|
||||
routes = append(routes, *meshNet)
|
||||
|
||||
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -302,7 +305,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
|
||||
peerToClients := make(map[string][]net.IPNet)
|
||||
installedRoutes := make([]lib.Route, 0)
|
||||
peerConfigs := make([]wgtypes.PeerConfig, 0)
|
||||
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
|
||||
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -393,7 +396,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
|
||||
return mn.GetType() == conf.CLIENT_ROLE
|
||||
})
|
||||
|
||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
||||
self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -432,11 +435,15 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
|
||||
func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) {
|
||||
allRoutes := make(map[string][]routeNode)
|
||||
|
||||
for _, mesh := range m.meshManager.GetMeshes() {
|
||||
routes := m.getRoutes(mesh)
|
||||
routes, err := m.getRoutes(mesh)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for destination, route := range routes {
|
||||
_, ok := allRoutes[destination]
|
||||
@ -454,11 +461,15 @@ func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
|
||||
}
|
||||
}
|
||||
|
||||
return allRoutes
|
||||
return allRoutes, nil
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||
allRoutes := m.getAllRoutes()
|
||||
allRoutes, err := m.getAllRoutes()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, mesh := range m.meshManager.GetMeshes() {
|
||||
err := m.updateWgConf(mesh, allRoutes)
|
||||
|
@ -24,7 +24,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
continue
|
||||
}
|
||||
|
||||
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
||||
self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -90,11 +90,20 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
|
||||
// Calculate the set different of each, working out routes to remove and to keep.
|
||||
for meshId, meshRoutes := range routes {
|
||||
mesh := r.meshManager.GetMesh(meshId)
|
||||
self, _ := r.meshManager.GetSelf(meshId)
|
||||
toRemove := make([]Route, 0)
|
||||
mesh := meshes[meshId]
|
||||
|
||||
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
|
||||
self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toRemove := make([]Route, 0)
|
||||
prevRoutes, err := mesh.GetRoutes(NodeID(self))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, route := range prevRoutes {
|
||||
if !lib.Contains(meshRoutes, func(r Route) bool {
|
||||
|
@ -81,6 +81,10 @@ func NodeEquals(node1, node2 MeshNode) bool {
|
||||
key1, _ := node1.GetPublicKey()
|
||||
key2, _ := node2.GetPublicKey()
|
||||
|
||||
if node1 == nil || node2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return key1.String() == key2.String()
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
@ -16,7 +15,7 @@ import (
|
||||
|
||||
// Syncer: picks random nodes from the meshs
|
||||
type Syncer interface {
|
||||
Sync(meshId string) error
|
||||
Sync(theMesh mesh.MeshProvider) error
|
||||
SyncMeshes() error
|
||||
}
|
||||
|
||||
@ -30,21 +29,33 @@ type SyncerImpl struct {
|
||||
lastSync map[string]uint64
|
||||
}
|
||||
|
||||
// Sync: Sync random nodes
|
||||
func (s *SyncerImpl) Sync(meshId string) error {
|
||||
// Self can be nil if the node is removed
|
||||
self, _ := s.manager.GetSelf(meshId)
|
||||
// Sync: Sync with random nodes
|
||||
func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error {
|
||||
if correspondingMesh == nil {
|
||||
return fmt.Errorf("mesh provided was nil cannot sync nil mesh")
|
||||
}
|
||||
|
||||
correspondingMesh := s.manager.GetMesh(meshId)
|
||||
// Self can be nil if the node is removed
|
||||
selfID := s.manager.GetPublicKey()
|
||||
self, _ := correspondingMesh.GetNode(selfID.String())
|
||||
|
||||
// Mesh has been removed
|
||||
if self == nil {
|
||||
return fmt.Errorf("mesh %s does not exist", correspondingMesh.GetMeshId())
|
||||
}
|
||||
|
||||
correspondingMesh.Prune()
|
||||
|
||||
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
||||
logging.Log.WriteInfof("No changes for %s", meshId)
|
||||
if correspondingMesh.HasChanges() {
|
||||
logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
|
||||
}
|
||||
|
||||
if self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
|
||||
logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
|
||||
|
||||
// If not synchronised in certain pull from random neighbour
|
||||
if uint64(time.Now().Unix())-s.lastSync[meshId] > 20 {
|
||||
return s.Pull(meshId)
|
||||
if uint64(time.Now().Unix())-s.lastSync[correspondingMesh.GetMeshId()] > 20 {
|
||||
return s.Pull(self, correspondingMesh)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -87,14 +98,14 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
|
||||
// Do this synchronously to conserve bandwidth
|
||||
for _, node := range gossipNodes {
|
||||
correspondingPeer := s.manager.GetNode(meshId, node)
|
||||
correspondingPeer := s.manager.GetNode(correspondingMesh.GetMeshId(), node)
|
||||
|
||||
if correspondingPeer == nil {
|
||||
logging.Log.WriteErrorf("node %s does not exist", node)
|
||||
continue
|
||||
}
|
||||
|
||||
err := s.requester.SyncMesh(meshId, correspondingPeer)
|
||||
err := s.requester.SyncMesh(correspondingMesh.GetMeshId(), correspondingPeer)
|
||||
|
||||
if err == nil || err == io.EOF {
|
||||
succeeded = true
|
||||
@ -116,36 +127,18 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
s.infectionCount++
|
||||
}
|
||||
|
||||
s.manager.GetMesh(meshId).SaveChanges()
|
||||
s.lastSync[meshId] = uint64(time.Now().Unix())
|
||||
|
||||
logging.Log.WriteInfof("UPDATING WG CONF")
|
||||
err := s.manager.ApplyConfig()
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||
}
|
||||
correspondingMesh.SaveChanges()
|
||||
|
||||
s.lastSync[correspondingMesh.GetMeshId()] = uint64(time.Now().Unix())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pull one node in the cluster, if there has not been message dissemination
|
||||
// in a certain period of time pull a random node within the cluster
|
||||
func (s *SyncerImpl) Pull(meshId string) error {
|
||||
mesh := s.manager.GetMesh(meshId)
|
||||
self, err := s.manager.GetSelf(meshId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) error {
|
||||
peers := mesh.GetPeers()
|
||||
pubKey, _ := self.GetPublicKey()
|
||||
|
||||
if mesh == nil {
|
||||
return errors.New("mesh is nil, invalid operation")
|
||||
}
|
||||
|
||||
peers := mesh.GetPeers()
|
||||
neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
|
||||
neighbour := lib.RandomSubsetOfLength(neighbours, 1)
|
||||
|
||||
@ -162,10 +155,10 @@ func (s *SyncerImpl) Pull(meshId string) error {
|
||||
return fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
|
||||
}
|
||||
|
||||
err = s.requester.SyncMesh(meshId, pullNode)
|
||||
err = s.requester.SyncMesh(mesh.GetMeshId(), pullNode)
|
||||
|
||||
if err == nil || err == io.EOF {
|
||||
s.lastSync[meshId] = uint64(time.Now().Unix())
|
||||
s.lastSync[mesh.GetMeshId()] = uint64(time.Now().Unix())
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
@ -176,14 +169,20 @@ func (s *SyncerImpl) Pull(meshId string) error {
|
||||
|
||||
// SyncMeshes: Sync all meshes
|
||||
func (s *SyncerImpl) SyncMeshes() error {
|
||||
for meshId := range s.manager.GetMeshes() {
|
||||
err := s.Sync(meshId)
|
||||
for _, mesh := range s.manager.GetMeshes() {
|
||||
err := s.Sync(mesh)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
logging.Log.WriteInfof("updating the WireGuard configuration")
|
||||
err := s.manager.ApplyConfig()
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof("failed to update config %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user