Bugfix. Fixed issue where consistent hashing was not working.
This commit is contained in:
Tim Beatham 2023-11-28 14:42:09 +00:00
parent 1fae0a6c2c
commit 32e7e4c7df
11 changed files with 180 additions and 81 deletions

View File

@ -45,21 +45,26 @@ func main() {
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf,
CtrlProvider: &robinRpc,
SyncProvider: &syncProvider,
Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
}
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
routeScheduler := timer.NewRouteScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -79,7 +84,6 @@ func main() {
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
go routeScheduler.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")

View File

@ -4,7 +4,9 @@ import (
"errors"
"fmt"
"net"
"slices"
"strings"
"sync"
"time"
"github.com/automerge/automerge-go"
@ -18,6 +20,7 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct {
lock sync.RWMutex
MeshId string
IfName string
Client *wgctrl.Client
@ -39,10 +42,13 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.lock.Lock()
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
c.lock.Unlock()
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
c.lock.RLock()
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
@ -50,6 +56,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
}
nodeType, err := node.Map().Get("type")
c.lock.RUnlock()
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
@ -61,6 +68,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
c.lock.RLock()
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
@ -68,6 +76,7 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
}
timestamp, err := node.Map().Get("timestamp")
c.lock.RUnlock()
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
@ -78,7 +87,9 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
}
func (c *CrdtMeshManager) GetPeers() []string {
c.lock.RLock()
keys, _ := c.doc.Path("nodes").Map().Keys()
c.lock.RUnlock()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
@ -97,7 +108,9 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
c.lock.RLock()
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
c.lock.RUnlock()
if err != nil {
return nil, err
@ -157,6 +170,7 @@ func (m *CrdtMeshManager) NodeExists(key string) bool {
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
@ -168,6 +182,7 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
}
meshNode, err := automerge.As[*MeshNodeCrdt](node)
m.lock.RUnlock()
if err != nil {
return nil, err
@ -213,7 +228,9 @@ func (m *CrdtMeshManager) SaveChanges() {
}
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil {
return err
@ -223,7 +240,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
return errors.New("node is not a map")
}
m.lock.Lock()
err = node.Map().Set("timestamp", time.Now().Unix())
m.lock.Unlock()
if err == nil {
logging.Log.WriteInfof("Timestamp Updated for %s", nodeId)
@ -233,7 +252,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
}
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil {
return err
@ -243,7 +264,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return fmt.Errorf("%s does not exist", nodeId)
}
m.lock.Lock()
err = node.Map().Set("description", description)
m.lock.Unlock()
if err == nil {
logging.Log.WriteInfof("Description Updated for %s", nodeId)
@ -253,7 +276,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
}
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil {
return err
@ -263,7 +288,9 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
return fmt.Errorf("%s does not exist", nodeId)
}
m.lock.Lock()
err = node.Map().Set("alias", alias)
m.lock.Unlock()
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
@ -273,13 +300,17 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
m.lock.RLock()
service, err := node.Map().Get("services")
m.lock.RUnlock()
if err != nil {
return err
@ -289,10 +320,14 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
m.lock.Lock()
err = service.Map().Set(key, value)
m.lock.Unlock()
return err
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
@ -308,8 +343,11 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
m.lock.RUnlock()
m.lock.Lock()
err = service.Map().Delete(key)
m.lock.Unlock()
if err != nil {
return fmt.Errorf("service %s does not exist", key)
@ -320,6 +358,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
m.lock.RLock()
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -332,16 +371,41 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
}
routeMap, err := nodeVal.Map().Get("routes")
m.lock.RUnlock()
if err != nil {
return err
}
for _, route := range routes {
prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
if prevRoute.Kind() == automerge.KindVoid && err != nil {
path, err := prevRoute.Map().Get("path")
if err != nil {
return err
}
if path.Kind() != automerge.KindList {
return fmt.Errorf("path is not a list")
}
pathStr, err := automerge.As[[]string](path)
if err != nil {
return err
}
slices.Equal(route.GetPath(), pathStr)
}
m.lock.Lock()
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
m.lock.Unlock()
if err != nil {
return err
@ -351,6 +415,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
}
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
m.lock.RLock()
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
@ -372,6 +437,7 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
}
routes, err := automerge.As[map[string]Route](routeMap)
m.lock.RUnlock()
return lib.MapValues(routes), err
}
@ -385,10 +451,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
// Work out the other routes in the mesh
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
@ -399,6 +467,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
@ -411,8 +485,16 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
return routes, nil
}
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
m.lock.Lock()
err := m.doc.Path("nodes").Map().Delete(nodeId)
m.lock.Unlock()
return err
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
m.lock.RLock()
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
@ -424,14 +506,17 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
}
routeMap, err := nodeVal.Map().Get("routes")
m.lock.RUnlock()
if err != nil {
return err
}
m.lock.Lock()
for _, route := range routes {
err = routeMap.Map().Delete(route)
}
m.lock.Unlock()
return err
}
@ -441,6 +526,7 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
m.lock.RLock()
nodes, err := m.doc.Path("nodes").Get()
if err != nil {
@ -452,6 +538,7 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
}
values, err := nodes.Map().Values()
m.lock.RUnlock()
if err != nil {
return err
@ -466,7 +553,9 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodeMap := node.Map()
m.lock.RLock()
timeStamp, err := nodeMap.Get("timestamp")
m.lock.RUnlock()
if err != nil {
return err

View File

@ -32,7 +32,6 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
func (a *AutomergeSync) Complete() {
logging.Log.WriteInfof("Sync Completed")
a.manager.SaveChanges()
}
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {

View File

@ -21,6 +21,7 @@ type NewCtrlServerParams struct {
CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer
Querier query.Querier
OnDelete func(mesh.MeshProvider)
}
// Create a new instance of the MeshCtrlServer or error if the
@ -46,6 +47,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer,
OnDelete: params.OnDelete,
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)

View File

@ -9,6 +9,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -32,10 +33,6 @@ type routeNode struct {
route Route
}
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
@ -63,9 +60,10 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
for _, route := range node.GetRoutes() {
bestRoutes := routes[route.GetDestination().String()]
var pickedRoute routeNode
if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination())
pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
@ -77,11 +75,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
}
// Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
}
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination())
}
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *pickedRoute.route.GetDestination())
}
}
@ -101,6 +99,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
Endpoint: endpoint,
AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
ReplaceAllowedIPs: true,
}
return &peerConfig, nil
@ -122,14 +121,9 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
meshRoutes, _ := meshProvider.GetRoutes(pubKey.String())
for _, route := range meshRoutes {
for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix == nil || route == nil || route.GetDestination() == nil {
return false
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
@ -150,6 +144,8 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
routes[destination] = append(otherRoute, rn)
}
}
@ -218,7 +214,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
@ -240,13 +235,13 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
if err != nil {
return err
}
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
@ -275,8 +270,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
}
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 0),
Peers: make([]wgtypes.PeerConfig, 0),
})
return nil

View File

@ -3,6 +3,7 @@ package mesh
import (
"errors"
"fmt"
"sync"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
@ -18,7 +19,7 @@ type MeshManager interface {
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
GetPublicKey(meshId string) (*wgtypes.Key, error)
GetPublicKey() *wgtypes.Key
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
@ -38,6 +39,7 @@ type MeshManager interface {
}
type MeshManagerImpl struct {
lock sync.RWMutex
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
@ -52,6 +54,7 @@ type MeshManagerImpl struct {
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
OnDelete func(MeshProvider)
}
// GetRouteManager implements MeshManager.
@ -149,7 +152,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.lock.Lock()
m.Meshes[meshId] = nodeManager
m.lock.Unlock()
return meshId, nil
}
@ -190,7 +195,9 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err
}
m.lock.Lock()
m.Meshes[params.MeshId] = meshProvider
m.lock.Unlock()
return nil
}
@ -206,25 +213,14 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
return (*wgtypes.Key)(zeroedKey)
}
mesh, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
key := s.HostParameters.PrivateKey.PublicKey()
return &key
}
type AddSelfParams struct {
@ -289,14 +285,29 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
// LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId]
mesh := s.GetMesh(meshId)
if !exists {
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
var err error
s.RouteManager.RemoveRoutes(meshId)
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err
}
if s.OnDelete != nil {
s.OnDelete(mesh)
}
s.lock.Lock()
delete(s.Meshes, meshId)
s.lock.Unlock()
if !s.conf.StubWg {
device, err := mesh.GetDevice()
@ -311,8 +322,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
}
}
err = s.RouteManager.RemoveRoutes(meshId)
delete(s.Meshes, meshId)
return err
}
@ -348,7 +357,8 @@ func (s *MeshManagerImpl) ApplyConfig() error {
}
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
@ -363,7 +373,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
@ -377,7 +388,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
@ -395,7 +407,16 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
}
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes
meshes := make(map[string]MeshProvider)
s.lock.RLock()
for id, mesh := range s.Meshes {
meshes[id] = mesh
}
s.lock.RUnlock()
return meshes
}
// Close the mesh manager
@ -432,6 +453,7 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
OnDelete func(MeshProvider)
}
// Creates a new instance of a mesh manager with the given parameters
@ -466,5 +488,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete
return m
}

View File

@ -81,6 +81,11 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
@ -287,9 +292,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
key, _ := wgtypes.GenerateKey()
return &key, nil
return &key
}
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {

View File

@ -138,6 +138,8 @@ type MeshProvider interface {
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
}
// HostParameters contains the IDs of a node

View File

@ -36,28 +36,17 @@ func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig()
s.manager.GetRouteManager().UpdateRoutes()
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
publicKey := s.manager.GetPublicKey()
nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
selfPublickey, err := self.GetPublicKey()
if err != nil {
return err
}
neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
@ -68,7 +57,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String())
interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
randomSubset = append(randomSubset, interCluster)
}
@ -102,6 +91,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
s.manager.GetMesh(meshId).SaveChanges()
return nil
}

View File

@ -12,7 +12,6 @@ func syncFunction(syncer Syncer) lib.TimerFunc {
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
}

View File

@ -12,11 +12,3 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
}
return *lib.NewTimer(timerFunc, 10)
}