Bugfix. Fixed issue where consistent hashing was not working.
This commit is contained in:
Tim Beatham 2023-11-28 14:42:09 +00:00
parent 1fae0a6c2c
commit 32e7e4c7df
11 changed files with 180 additions and 81 deletions

View File

@ -45,21 +45,26 @@ func main() {
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{ ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf, Conf: conf,
CtrlProvider: &robinRpc, CtrlProvider: &robinRpc,
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester = sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer) timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
routeScheduler := timer.NewRouteScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -79,7 +84,6 @@ func main() {
go syncScheduler.Run() go syncScheduler.Run()
go timestampScheduler.Run() go timestampScheduler.Run()
go pruneScheduler.Run() go pruneScheduler.Run()
go routeScheduler.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("Closing resources")

View File

@ -4,7 +4,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"slices"
"strings" "strings"
"sync"
"time" "time"
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
@ -18,6 +20,7 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh // CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct { type CrdtMeshManager struct {
lock sync.RWMutex
MeshId string MeshId string
IfName string IfName string
Client *wgctrl.Client Client *wgctrl.Client
@ -39,10 +42,13 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt.Services = make(map[string]string) crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.lock.Lock()
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
c.lock.Unlock()
} }
func (c *CrdtMeshManager) isPeer(nodeId string) bool { func (c *CrdtMeshManager) isPeer(nodeId string) bool {
c.lock.RLock()
node, err := c.doc.Path("nodes").Map().Get(nodeId) node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap { if err != nil || node.Kind() != automerge.KindMap {
@ -50,6 +56,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
} }
nodeType, err := node.Map().Get("type") nodeType, err := node.Map().Get("type")
c.lock.RUnlock()
if err != nil || nodeType.Kind() != automerge.KindStr { if err != nil || nodeType.Kind() != automerge.KindStr {
return false return false
@ -61,6 +68,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
// isAlive: checks that the node's configuration has been updated // isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time // since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool { func (c *CrdtMeshManager) isAlive(nodeId string) bool {
c.lock.RLock()
node, err := c.doc.Path("nodes").Map().Get(nodeId) node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap { if err != nil || node.Kind() != automerge.KindMap {
@ -68,6 +76,7 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
} }
timestamp, err := node.Map().Get("timestamp") timestamp, err := node.Map().Get("timestamp")
c.lock.RUnlock()
if err != nil || timestamp.Kind() != automerge.KindInt64 { if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false return false
@ -78,7 +87,9 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
} }
func (c *CrdtMeshManager) GetPeers() []string { func (c *CrdtMeshManager) GetPeers() []string {
c.lock.RLock()
keys, _ := c.doc.Path("nodes").Map().Keys() keys, _ := c.doc.Path("nodes").Map().Keys()
c.lock.RUnlock()
keys = lib.Filter(keys, func(publicKey string) bool { keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey) return c.isPeer(publicKey) && c.isAlive(publicKey)
@ -97,7 +108,9 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
if c.cache == nil || len(changes) > 0 { if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash c.lastCacheHash = c.LastHash
c.lock.RLock()
cache, err := automerge.As[*MeshCrdt](c.doc.Root()) cache, err := automerge.As[*MeshCrdt](c.doc.Root())
c.lock.RUnlock()
if err != nil { if err != nil {
return nil, err return nil, err
@ -157,6 +170,7 @@ func (m *CrdtMeshManager) NodeExists(key string) bool {
} }
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap { if node.Kind() != automerge.KindMap {
@ -168,6 +182,7 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
} }
meshNode, err := automerge.As[*MeshNodeCrdt](node) meshNode, err := automerge.As[*MeshNodeCrdt](node)
m.lock.RUnlock()
if err != nil { if err != nil {
return nil, err return nil, err
@ -213,7 +228,9 @@ func (m *CrdtMeshManager) SaveChanges() {
} }
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
@ -223,7 +240,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
return errors.New("node is not a map") return errors.New("node is not a map")
} }
m.lock.Lock()
err = node.Map().Set("timestamp", time.Now().Unix()) err = node.Map().Set("timestamp", time.Now().Unix())
m.lock.Unlock()
if err == nil { if err == nil {
logging.Log.WriteInfof("Timestamp Updated for %s", nodeId) logging.Log.WriteInfof("Timestamp Updated for %s", nodeId)
@ -233,7 +252,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
} }
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error { func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
@ -243,7 +264,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return fmt.Errorf("%s does not exist", nodeId) return fmt.Errorf("%s does not exist", nodeId)
} }
m.lock.Lock()
err = node.Map().Set("description", description) err = node.Map().Set("description", description)
m.lock.Unlock()
if err == nil { if err == nil {
logging.Log.WriteInfof("Description Updated for %s", nodeId) logging.Log.WriteInfof("Description Updated for %s", nodeId)
@ -253,7 +276,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
} }
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
@ -263,7 +288,9 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
return fmt.Errorf("%s does not exist", nodeId) return fmt.Errorf("%s does not exist", nodeId)
} }
m.lock.Lock()
err = node.Map().Set("alias", alias) err = node.Map().Set("alias", alias)
m.lock.Unlock()
if err == nil { if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias) logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
@ -273,13 +300,17 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
} }
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
m.lock.RUnlock()
if err != nil || node.Kind() != automerge.KindMap { if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId) return fmt.Errorf("AddService: node %s does not exist", nodeId)
} }
m.lock.RLock()
service, err := node.Map().Get("services") service, err := node.Map().Get("services")
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
@ -289,10 +320,14 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return fmt.Errorf("AddService: services property does not exist in node") return fmt.Errorf("AddService: services property does not exist in node")
} }
return service.Map().Set(key, value) m.lock.Lock()
err = service.Map().Set(key, value)
m.lock.Unlock()
return err
} }
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
m.lock.RLock()
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap { if err != nil || node.Kind() != automerge.KindMap {
@ -308,8 +343,11 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
if service.Kind() != automerge.KindMap { if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist") return fmt.Errorf("services property does not exist")
} }
m.lock.RUnlock()
m.lock.Lock()
err = service.Map().Delete(key) err = service.Map().Delete(key)
m.lock.Unlock()
if err != nil { if err != nil {
return fmt.Errorf("service %s does not exist", key) return fmt.Errorf("service %s does not exist", key)
@ -320,6 +358,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
m.lock.RLock()
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -332,16 +371,41 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
} }
routeMap, err := nodeVal.Map().Get("routes") routeMap, err := nodeVal.Map().Get("routes")
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
} }
for _, route := range routes { for _, route := range routes {
prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
if prevRoute.Kind() == automerge.KindVoid && err != nil {
path, err := prevRoute.Map().Get("path")
if err != nil {
return err
}
if path.Kind() != automerge.KindList {
return fmt.Errorf("path is not a list")
}
pathStr, err := automerge.As[[]string](path)
if err != nil {
return err
}
slices.Equal(route.GetPath(), pathStr)
}
m.lock.Lock()
err = routeMap.Map().Set(route.GetDestination().String(), Route{ err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(), Destination: route.GetDestination().String(),
Path: route.GetPath(), Path: route.GetPath(),
}) })
m.lock.Unlock()
if err != nil { if err != nil {
return err return err
@ -351,6 +415,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
} }
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
m.lock.RLock()
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil { if err != nil {
@ -372,6 +437,7 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
} }
routes, err := automerge.As[map[string]Route](routeMap) routes, err := automerge.As[map[string]Route](routeMap)
m.lock.RUnlock()
return lib.MapValues(routes), err return lib.MapValues(routes), err
} }
@ -385,10 +451,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
routes := make(map[string]mesh.Route) routes := make(map[string]mesh.Route)
// Add routes that the node directly has
for _, route := range node.GetRoutes() { for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route routes[route.GetDestination().String()] = route
} }
// Work out the other routes in the mesh
for _, node := range m.GetPeers() { for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node) nodeRoutes, err := m.getRoutes(node)
@ -399,6 +467,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
for _, route := range nodeRoutes { for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()] otherRoute, ok := routes[route.GetDestination().String()]
hopCount := route.GetHopCount()
if node != targetNode {
hopCount += 1
}
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() { if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{ routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(), Destination: route.GetDestination().String(),
@ -411,8 +485,16 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
return routes, nil return routes, nil
} }
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
m.lock.Lock()
err := m.doc.Path("nodes").Map().Delete(nodeId)
m.lock.Unlock()
return err
}
// DeleteRoutes deletes the specified routes // DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
m.lock.RLock()
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil { if err != nil {
@ -424,14 +506,17 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
} }
routeMap, err := nodeVal.Map().Get("routes") routeMap, err := nodeVal.Map().Get("routes")
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
} }
m.lock.Lock()
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Delete(route) err = routeMap.Map().Delete(route)
} }
m.lock.Unlock()
return err return err
} }
@ -441,6 +526,7 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
} }
func (m *CrdtMeshManager) Prune(pruneTime int) error { func (m *CrdtMeshManager) Prune(pruneTime int) error {
m.lock.RLock()
nodes, err := m.doc.Path("nodes").Get() nodes, err := m.doc.Path("nodes").Get()
if err != nil { if err != nil {
@ -452,6 +538,7 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
} }
values, err := nodes.Map().Values() values, err := nodes.Map().Values()
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err
@ -466,7 +553,9 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodeMap := node.Map() nodeMap := node.Map()
m.lock.RLock()
timeStamp, err := nodeMap.Get("timestamp") timeStamp, err := nodeMap.Get("timestamp")
m.lock.RUnlock()
if err != nil { if err != nil {
return err return err

View File

@ -32,7 +32,6 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
func (a *AutomergeSync) Complete() { func (a *AutomergeSync) Complete() {
logging.Log.WriteInfof("Sync Completed") logging.Log.WriteInfof("Sync Completed")
a.manager.SaveChanges()
} }
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {

View File

@ -21,6 +21,7 @@ type NewCtrlServerParams struct {
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
Querier query.Querier Querier query.Querier
OnDelete func(mesh.MeshProvider)
} }
// Create a new instance of the MeshCtrlServer or error if the // Create a new instance of the MeshCtrlServer or error if the
@ -46,6 +47,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IPAllocator: ipAllocator, IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator, InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer, ConfigApplyer: configApplyer,
OnDelete: params.OnDelete,
} }
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)

View File

@ -9,6 +9,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route" "github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -32,10 +33,6 @@ type routeNode struct {
route Route route Route
} }
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device, func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet, peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) { routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
@ -63,9 +60,10 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
for _, route := range node.GetRoutes() { for _, route := range node.GetRoutes() {
bestRoutes := routes[route.GetDestination().String()] bestRoutes := routes[route.GetDestination().String()]
var pickedRoute routeNode
if len(bestRoutes) == 1 { if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination()) pickedRoute = bestRoutes[0]
} else if len(bestRoutes) > 1 { } else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int { keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey() pubKey, _ := mn.GetPublicKey()
@ -77,11 +75,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
} }
// Else there is more than one candidate so consistently hash // Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc) pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
}
if pickedRoute.gateway == pubKey.String() { if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination()) allowedips = append(allowedips, *pickedRoute.route.GetDestination())
}
} }
} }
@ -101,6 +99,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
Endpoint: endpoint, Endpoint: endpoint,
AllowedIPs: allowedips, AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
ReplaceAllowedIPs: true,
} }
return &peerConfig, nil return &peerConfig, nil
@ -122,14 +121,9 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
for _, node := range mesh.GetNodes() { for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey() pubKey, _ := node.GetPublicKey()
meshRoutes, _ := meshProvider.GetRoutes(pubKey.String())
for _, route := range meshRoutes { for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix == nil || route == nil || route.GetDestination() == nil {
return false
}
return prefix.Contains(route.GetDestination().IP) return prefix.Contains(route.GetDestination().IP)
}) { }) {
continue continue
@ -150,6 +144,8 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() { } else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() { } else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
routes[destination] = append(otherRoute, rn) routes[destination] = append(otherRoute, rn)
} }
} }
@ -218,7 +214,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) { if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{ installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP, Gateway: n.GetWgHost().IP,
Destination: route, Destination: route,
@ -240,13 +235,13 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err return err
} }
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
if err != nil { if err != nil {
return err return err
} }
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
} }
func (m *WgMeshConfigApplyer) ApplyConfig() error { func (m *WgMeshConfigApplyer) ApplyConfig() error {
@ -275,7 +270,6 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
} }
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 0), Peers: make([]wgtypes.PeerConfig, 0),
}) })

View File

@ -3,6 +3,7 @@ package mesh
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
@ -18,7 +19,7 @@ type MeshManager interface {
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
GetPublicKey(meshId string) (*wgtypes.Key, error) GetPublicKey() *wgtypes.Key
AddSelf(params *AddSelfParams) error AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
@ -38,6 +39,7 @@ type MeshManager interface {
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
lock sync.RWMutex
Meshes map[string]MeshProvider Meshes map[string]MeshProvider
RouteManager RouteManager RouteManager RouteManager
Client *wgctrl.Client Client *wgctrl.Client
@ -52,6 +54,7 @@ type MeshManagerImpl struct {
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor Monitor MeshMonitor
OnDelete func(MeshProvider)
} }
// GetRouteManager implements MeshManager. // GetRouteManager implements MeshManager.
@ -149,7 +152,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
m.lock.Lock()
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
m.lock.Unlock()
return meshId, nil return meshId, nil
} }
@ -190,7 +195,9 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err return err
} }
m.lock.Lock()
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
m.lock.Unlock()
return nil return nil
} }
@ -206,25 +213,14 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
} }
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
if s.conf.StubWg { if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen) zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil return (*wgtypes.Key)(zeroedKey)
} }
mesh, ok := s.Meshes[meshId] key := s.HostParameters.PrivateKey.PublicKey()
return &key
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
} }
type AddSelfParams struct { type AddSelfParams struct {
@ -289,14 +285,29 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
// LeaveMesh leaves the mesh network // LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error { func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh, exists := s.Meshes[meshId] mesh := s.GetMesh(meshId)
if !exists { if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId) return fmt.Errorf("mesh %s does not exist", meshId)
} }
var err error var err error
s.RouteManager.RemoveRoutes(meshId)
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err
}
if s.OnDelete != nil {
s.OnDelete(mesh)
}
s.lock.Lock()
delete(s.Meshes, meshId)
s.lock.Unlock()
if !s.conf.StubWg { if !s.conf.StubWg {
device, err := mesh.GetDevice() device, err := mesh.GetDevice()
@ -311,8 +322,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
} }
} }
err = s.RouteManager.RemoveRoutes(meshId)
delete(s.Meshes, meshId)
return err return err
} }
@ -348,7 +357,8 @@ func (s *MeshManagerImpl) ApplyConfig() error {
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) { if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description) err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
@ -363,7 +373,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes { meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) { if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
@ -377,7 +388,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { meshes := s.GetMeshes()
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) { if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey()) err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
@ -395,7 +407,16 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
} }
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes meshes := make(map[string]MeshProvider)
s.lock.RLock()
for id, mesh := range s.Meshes {
meshes[id] = mesh
}
s.lock.RUnlock()
return meshes
} }
// Close the mesh manager // Close the mesh manager
@ -432,6 +453,7 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer ConfigApplyer MeshConfigApplyer
RouteManager RouteManager RouteManager RouteManager
OnDelete func(MeshProvider)
} }
// Creates a new instance of a mesh manager with the given parameters // Creates a new instance of a mesh manager with the given parameters
@ -466,5 +488,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
aliasManager := NewAliasManager() aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases) m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases) m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete
return m return m
} }

View File

@ -81,6 +81,11 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented")
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) { func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil return nil, nil
} }
@ -287,9 +292,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}} snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
} }
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
return &key, nil return &key
} }
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error { func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {

View File

@ -138,6 +138,8 @@ type MeshProvider interface {
GetPeers() []string GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen // GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error) GetRoutes(targetNode string) (map[string]Route, error)
// RemoveNode(): remove the node from the mesh
RemoveNode(nodeId string) error
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node

View File

@ -36,28 +36,17 @@ func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF") logging.Log.WriteInfof("UPDATING WG CONF")
if s.manager.HasChanges(meshId) { s.manager.GetRouteManager().UpdateRoutes()
err := s.manager.ApplyConfig() err := s.manager.ApplyConfig()
if err != nil { if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err) logging.Log.WriteInfof("Failed to update config %w", err)
} }
}
publicKey := s.manager.GetPublicKey()
nodeNames := s.manager.GetMesh(meshId).GetPeers() nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId) neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
if err != nil {
return err
}
selfPublickey, err := self.GetPublicKey()
if err != nil {
return err
}
neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset { for _, node := range randomSubset {
@ -68,7 +57,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster") logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String()) interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
randomSubset = append(randomSubset, interCluster) randomSubset = append(randomSubset, interCluster)
} }
@ -102,6 +91,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
// Check if any changes have occurred and trigger callbacks // Check if any changes have occurred and trigger callbacks
// if changes have occurred. // if changes have occurred.
// return s.manager.GetMonitor().Trigger() // return s.manager.GetMonitor().Trigger()
s.manager.GetMesh(meshId).SaveChanges()
return nil return nil
} }

View File

@ -12,7 +12,6 @@ func syncFunction(syncer Syncer) lib.TimerFunc {
} }
} }
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer { func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate)) return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
} }

View File

@ -12,11 +12,3 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
} }
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
}
return *lib.NewTimer(timerFunc, 10)
}