mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2024-12-04 21:50:49 +01:00
main
Bugfix. Fixed issue where consistent hashing was not working.
This commit is contained in:
parent
1fae0a6c2c
commit
32e7e4c7df
@ -45,21 +45,26 @@ func main() {
|
||||
var robinRpc robin.WgRpc
|
||||
var robinIpc robin.IpcHandler
|
||||
var syncProvider sync.SyncServiceImpl
|
||||
var syncRequester sync.SyncRequester
|
||||
var syncer sync.Syncer
|
||||
|
||||
ctrlServerParams := ctrlserver.NewCtrlServerParams{
|
||||
Conf: conf,
|
||||
CtrlProvider: &robinRpc,
|
||||
SyncProvider: &syncProvider,
|
||||
Client: client,
|
||||
OnDelete: func(mp mesh.MeshProvider) {
|
||||
syncer.SyncMeshes()
|
||||
},
|
||||
}
|
||||
|
||||
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
|
||||
syncProvider.Server = ctrlServer
|
||||
syncRequester := sync.NewSyncRequester(ctrlServer)
|
||||
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
|
||||
syncRequester = sync.NewSyncRequester(ctrlServer)
|
||||
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
|
||||
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
|
||||
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
|
||||
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
|
||||
routeScheduler := timer.NewRouteScheduler(ctrlServer)
|
||||
|
||||
robinIpcParams := robin.RobinIpcParams{
|
||||
CtrlServer: ctrlServer,
|
||||
@ -79,7 +84,6 @@ func main() {
|
||||
go syncScheduler.Run()
|
||||
go timestampScheduler.Run()
|
||||
go pruneScheduler.Run()
|
||||
go routeScheduler.Run()
|
||||
|
||||
closeResources := func() {
|
||||
logging.Log.WriteInfof("Closing resources")
|
||||
|
@ -4,7 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/automerge/automerge-go"
|
||||
@ -18,6 +20,7 @@ import (
|
||||
|
||||
// CrdtMeshManager manages nodes in the crdt mesh
|
||||
type CrdtMeshManager struct {
|
||||
lock sync.RWMutex
|
||||
MeshId string
|
||||
IfName string
|
||||
Client *wgctrl.Client
|
||||
@ -39,10 +42,13 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
|
||||
crdt.Services = make(map[string]string)
|
||||
crdt.Timestamp = time.Now().Unix()
|
||||
|
||||
c.lock.Lock()
|
||||
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
||||
c.lock.RLock()
|
||||
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil || node.Kind() != automerge.KindMap {
|
||||
@ -50,6 +56,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
||||
}
|
||||
|
||||
nodeType, err := node.Map().Get("type")
|
||||
c.lock.RUnlock()
|
||||
|
||||
if err != nil || nodeType.Kind() != automerge.KindStr {
|
||||
return false
|
||||
@ -61,6 +68,7 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
|
||||
// isAlive: checks that the node's configuration has been updated
|
||||
// since the rquired keep alive time
|
||||
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
||||
c.lock.RLock()
|
||||
node, err := c.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil || node.Kind() != automerge.KindMap {
|
||||
@ -68,6 +76,7 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
||||
}
|
||||
|
||||
timestamp, err := node.Map().Get("timestamp")
|
||||
c.lock.RUnlock()
|
||||
|
||||
if err != nil || timestamp.Kind() != automerge.KindInt64 {
|
||||
return false
|
||||
@ -78,7 +87,9 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
||||
}
|
||||
|
||||
func (c *CrdtMeshManager) GetPeers() []string {
|
||||
c.lock.RLock()
|
||||
keys, _ := c.doc.Path("nodes").Map().Keys()
|
||||
c.lock.RUnlock()
|
||||
|
||||
keys = lib.Filter(keys, func(publicKey string) bool {
|
||||
return c.isPeer(publicKey) && c.isAlive(publicKey)
|
||||
@ -97,7 +108,9 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
||||
|
||||
if c.cache == nil || len(changes) > 0 {
|
||||
c.lastCacheHash = c.LastHash
|
||||
c.lock.RLock()
|
||||
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
|
||||
c.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -157,6 +170,7 @@ func (m *CrdtMeshManager) NodeExists(key string) bool {
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
|
||||
m.lock.RLock()
|
||||
node, err := m.doc.Path("nodes").Map().Get(endpoint)
|
||||
|
||||
if node.Kind() != automerge.KindMap {
|
||||
@ -168,6 +182,7 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
|
||||
}
|
||||
|
||||
meshNode, err := automerge.As[*MeshNodeCrdt](node)
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -213,7 +228,9 @@ func (m *CrdtMeshManager) SaveChanges() {
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
||||
m.lock.RLock()
|
||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -223,7 +240,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
||||
return errors.New("node is not a map")
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
err = node.Map().Set("timestamp", time.Now().Unix())
|
||||
m.lock.Unlock()
|
||||
|
||||
if err == nil {
|
||||
logging.Log.WriteInfof("Timestamp Updated for %s", nodeId)
|
||||
@ -233,7 +252,9 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
|
||||
m.lock.RLock()
|
||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -243,7 +264,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
|
||||
return fmt.Errorf("%s does not exist", nodeId)
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
err = node.Map().Set("description", description)
|
||||
m.lock.Unlock()
|
||||
|
||||
if err == nil {
|
||||
logging.Log.WriteInfof("Description Updated for %s", nodeId)
|
||||
@ -253,7 +276,9 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
||||
m.lock.RLock()
|
||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -263,7 +288,9 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
||||
return fmt.Errorf("%s does not exist", nodeId)
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
err = node.Map().Set("alias", alias)
|
||||
m.lock.Unlock()
|
||||
|
||||
if err == nil {
|
||||
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
|
||||
@ -273,13 +300,17 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
|
||||
m.lock.RLock()
|
||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil || node.Kind() != automerge.KindMap {
|
||||
return fmt.Errorf("AddService: node %s does not exist", nodeId)
|
||||
}
|
||||
|
||||
m.lock.RLock()
|
||||
service, err := node.Map().Get("services")
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -289,10 +320,14 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
|
||||
return fmt.Errorf("AddService: services property does not exist in node")
|
||||
}
|
||||
|
||||
return service.Map().Set(key, value)
|
||||
m.lock.Lock()
|
||||
err = service.Map().Set(key, value)
|
||||
m.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
||||
m.lock.RLock()
|
||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil || node.Kind() != automerge.KindMap {
|
||||
@ -308,8 +343,11 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
||||
if service.Kind() != automerge.KindMap {
|
||||
return fmt.Errorf("services property does not exist")
|
||||
}
|
||||
m.lock.RUnlock()
|
||||
|
||||
m.lock.Lock()
|
||||
err = service.Map().Delete(key)
|
||||
m.lock.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("service %s does not exist", key)
|
||||
@ -320,6 +358,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
|
||||
|
||||
// AddRoutes: adds routes to the specific nodeId
|
||||
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
||||
m.lock.RLock()
|
||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
logging.Log.WriteInfof("Adding route to %s", nodeId)
|
||||
|
||||
@ -332,16 +371,41 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
||||
}
|
||||
|
||||
routeMap, err := nodeVal.Map().Get("routes")
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
prevRoute, err := routeMap.Map().Get(route.GetDestination().String())
|
||||
|
||||
if prevRoute.Kind() == automerge.KindVoid && err != nil {
|
||||
path, err := prevRoute.Map().Get("path")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if path.Kind() != automerge.KindList {
|
||||
return fmt.Errorf("path is not a list")
|
||||
}
|
||||
|
||||
pathStr, err := automerge.As[[]string](path)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slices.Equal(route.GetPath(), pathStr)
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
err = routeMap.Map().Set(route.GetDestination().String(), Route{
|
||||
Destination: route.GetDestination().String(),
|
||||
Path: route.GetPath(),
|
||||
})
|
||||
m.lock.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -351,6 +415,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
|
||||
m.lock.RLock()
|
||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil {
|
||||
@ -372,6 +437,7 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
|
||||
}
|
||||
|
||||
routes, err := automerge.As[map[string]Route](routeMap)
|
||||
m.lock.RUnlock()
|
||||
|
||||
return lib.MapValues(routes), err
|
||||
}
|
||||
@ -385,10 +451,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
|
||||
|
||||
routes := make(map[string]mesh.Route)
|
||||
|
||||
// Add routes that the node directly has
|
||||
for _, route := range node.GetRoutes() {
|
||||
routes[route.GetDestination().String()] = route
|
||||
}
|
||||
|
||||
// Work out the other routes in the mesh
|
||||
for _, node := range m.GetPeers() {
|
||||
nodeRoutes, err := m.getRoutes(node)
|
||||
|
||||
@ -399,6 +467,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
|
||||
for _, route := range nodeRoutes {
|
||||
otherRoute, ok := routes[route.GetDestination().String()]
|
||||
|
||||
hopCount := route.GetHopCount()
|
||||
|
||||
if node != targetNode {
|
||||
hopCount += 1
|
||||
}
|
||||
|
||||
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
|
||||
routes[route.GetDestination().String()] = &Route{
|
||||
Destination: route.GetDestination().String(),
|
||||
@ -411,8 +485,16 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
||||
m.lock.Lock()
|
||||
err := m.doc.Path("nodes").Map().Delete(nodeId)
|
||||
m.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteRoutes deletes the specified routes
|
||||
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
||||
m.lock.RLock()
|
||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil {
|
||||
@ -424,14 +506,17 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
||||
}
|
||||
|
||||
routeMap, err := nodeVal.Map().Get("routes")
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
for _, route := range routes {
|
||||
err = routeMap.Map().Delete(route)
|
||||
}
|
||||
m.lock.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
@ -441,6 +526,7 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
||||
m.lock.RLock()
|
||||
nodes, err := m.doc.Path("nodes").Get()
|
||||
|
||||
if err != nil {
|
||||
@ -452,6 +538,7 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
||||
}
|
||||
|
||||
values, err := nodes.Map().Values()
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -466,7 +553,9 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
||||
|
||||
nodeMap := node.Map()
|
||||
|
||||
m.lock.RLock()
|
||||
timeStamp, err := nodeMap.Get("timestamp")
|
||||
m.lock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -32,7 +32,6 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
|
||||
|
||||
func (a *AutomergeSync) Complete() {
|
||||
logging.Log.WriteInfof("Sync Completed")
|
||||
a.manager.SaveChanges()
|
||||
}
|
||||
|
||||
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
|
||||
|
@ -21,6 +21,7 @@ type NewCtrlServerParams struct {
|
||||
CtrlProvider rpc.MeshCtrlServerServer
|
||||
SyncProvider rpc.SyncServiceServer
|
||||
Querier query.Querier
|
||||
OnDelete func(mesh.MeshProvider)
|
||||
}
|
||||
|
||||
// Create a new instance of the MeshCtrlServer or error if the
|
||||
@ -46,6 +47,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||
IPAllocator: ipAllocator,
|
||||
InterfaceManipulator: interfaceManipulator,
|
||||
ConfigApplyer: configApplyer,
|
||||
OnDelete: params.OnDelete,
|
||||
}
|
||||
|
||||
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
"github.com/tim-beatham/wgmesh/pkg/route"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@ -32,10 +33,6 @@ type routeNode struct {
|
||||
route Route
|
||||
}
|
||||
|
||||
func (r *routeNode) equals(route2 *routeNode) bool {
|
||||
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
|
||||
peerToClients map[string][]net.IPNet,
|
||||
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
|
||||
@ -63,9 +60,10 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
||||
|
||||
for _, route := range node.GetRoutes() {
|
||||
bestRoutes := routes[route.GetDestination().String()]
|
||||
var pickedRoute routeNode
|
||||
|
||||
if len(bestRoutes) == 1 {
|
||||
allowedips = append(allowedips, *route.GetDestination())
|
||||
pickedRoute = bestRoutes[0]
|
||||
} else if len(bestRoutes) > 1 {
|
||||
keyFunc := func(mn MeshNode) int {
|
||||
pubKey, _ := mn.GetPublicKey()
|
||||
@ -77,11 +75,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
||||
}
|
||||
|
||||
// Else there is more than one candidate so consistently hash
|
||||
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
|
||||
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
|
||||
}
|
||||
|
||||
if pickedRoute.gateway == pubKey.String() {
|
||||
allowedips = append(allowedips, *route.GetDestination())
|
||||
}
|
||||
if pickedRoute.gateway == pubKey.String() {
|
||||
allowedips = append(allowedips, *pickedRoute.route.GetDestination())
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,6 +99,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
||||
Endpoint: endpoint,
|
||||
AllowedIPs: allowedips,
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
ReplaceAllowedIPs: true,
|
||||
}
|
||||
|
||||
return &peerConfig, nil
|
||||
@ -122,14 +121,9 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
|
||||
for _, node := range mesh.GetNodes() {
|
||||
pubKey, _ := node.GetPublicKey()
|
||||
meshRoutes, _ := meshProvider.GetRoutes(pubKey.String())
|
||||
|
||||
for _, route := range meshRoutes {
|
||||
for _, route := range node.GetRoutes() {
|
||||
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
|
||||
if prefix == nil || route == nil || route.GetDestination() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return prefix.Contains(route.GetDestination().IP)
|
||||
}) {
|
||||
continue
|
||||
@ -150,6 +144,8 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
||||
otherRoute[0] = rn
|
||||
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
||||
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
|
||||
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
|
||||
routes[destination] = append(otherRoute, rn)
|
||||
}
|
||||
}
|
||||
@ -218,7 +214,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||
|
||||
if !ipNet.Contains(route.IP) {
|
||||
|
||||
installedRoutes = append(installedRoutes, lib.Route{
|
||||
Gateway: n.GetWgHost().IP,
|
||||
Destination: route,
|
||||
@ -240,13 +235,13 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
||||
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
|
||||
return m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||
@ -275,8 +270,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
|
||||
}
|
||||
|
||||
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
|
||||
ReplacePeers: true,
|
||||
Peers: make([]wgtypes.PeerConfig, 0),
|
||||
Peers: make([]wgtypes.PeerConfig, 0),
|
||||
})
|
||||
|
||||
return nil
|
||||
|
@ -3,6 +3,7 @@ package mesh
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||
@ -18,7 +19,7 @@ type MeshManager interface {
|
||||
AddMesh(params *AddMeshParams) error
|
||||
HasChanges(meshid string) bool
|
||||
GetMesh(meshId string) MeshProvider
|
||||
GetPublicKey(meshId string) (*wgtypes.Key, error)
|
||||
GetPublicKey() *wgtypes.Key
|
||||
AddSelf(params *AddSelfParams) error
|
||||
LeaveMesh(meshId string) error
|
||||
GetSelf(meshId string) (MeshNode, error)
|
||||
@ -38,6 +39,7 @@ type MeshManager interface {
|
||||
}
|
||||
|
||||
type MeshManagerImpl struct {
|
||||
lock sync.RWMutex
|
||||
Meshes map[string]MeshProvider
|
||||
RouteManager RouteManager
|
||||
Client *wgctrl.Client
|
||||
@ -52,6 +54,7 @@ type MeshManagerImpl struct {
|
||||
ipAllocator ip.IPAllocator
|
||||
interfaceManipulator wg.WgInterfaceManipulator
|
||||
Monitor MeshMonitor
|
||||
OnDelete func(MeshProvider)
|
||||
}
|
||||
|
||||
// GetRouteManager implements MeshManager.
|
||||
@ -149,7 +152,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
||||
return "", fmt.Errorf("error creating mesh: %w", err)
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
m.Meshes[meshId] = nodeManager
|
||||
m.lock.Unlock()
|
||||
return meshId, nil
|
||||
}
|
||||
|
||||
@ -190,7 +195,9 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
m.Meshes[params.MeshId] = meshProvider
|
||||
m.lock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -206,25 +213,14 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
|
||||
}
|
||||
|
||||
// GetPublicKey: Gets the public key of the WireGuard mesh
|
||||
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
|
||||
func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
|
||||
if s.conf.StubWg {
|
||||
zeroedKey := make([]byte, wgtypes.KeyLen)
|
||||
return (*wgtypes.Key)(zeroedKey), nil
|
||||
return (*wgtypes.Key)(zeroedKey)
|
||||
}
|
||||
|
||||
mesh, ok := s.Meshes[meshId]
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("mesh does not exist")
|
||||
}
|
||||
|
||||
dev, err := mesh.GetDevice()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dev.PublicKey, nil
|
||||
key := s.HostParameters.PrivateKey.PublicKey()
|
||||
return &key
|
||||
}
|
||||
|
||||
type AddSelfParams struct {
|
||||
@ -289,14 +285,29 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
|
||||
|
||||
// LeaveMesh leaves the mesh network
|
||||
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
||||
mesh, exists := s.Meshes[meshId]
|
||||
mesh := s.GetMesh(meshId)
|
||||
|
||||
if !exists {
|
||||
if mesh == nil {
|
||||
return fmt.Errorf("mesh %s does not exist", meshId)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
s.RouteManager.RemoveRoutes(meshId)
|
||||
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.OnDelete != nil {
|
||||
s.OnDelete(mesh)
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
delete(s.Meshes, meshId)
|
||||
s.lock.Unlock()
|
||||
|
||||
if !s.conf.StubWg {
|
||||
device, err := mesh.GetDevice()
|
||||
|
||||
@ -311,8 +322,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
||||
}
|
||||
}
|
||||
|
||||
err = s.RouteManager.RemoveRoutes(meshId)
|
||||
delete(s.Meshes, meshId)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -348,7 +357,8 @@ func (s *MeshManagerImpl) ApplyConfig() error {
|
||||
}
|
||||
|
||||
func (s *MeshManagerImpl) SetDescription(description string) error {
|
||||
for _, mesh := range s.Meshes {
|
||||
meshes := s.GetMeshes()
|
||||
for _, mesh := range meshes {
|
||||
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
|
||||
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
|
||||
|
||||
@ -363,7 +373,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
|
||||
|
||||
// SetAlias implements MeshManager.
|
||||
func (s *MeshManagerImpl) SetAlias(alias string) error {
|
||||
for _, mesh := range s.Meshes {
|
||||
meshes := s.GetMeshes()
|
||||
for _, mesh := range meshes {
|
||||
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
|
||||
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
|
||||
|
||||
@ -377,7 +388,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
|
||||
|
||||
// UpdateTimeStamp updates the timestamp of this node in all meshes
|
||||
func (s *MeshManagerImpl) UpdateTimeStamp() error {
|
||||
for _, mesh := range s.Meshes {
|
||||
meshes := s.GetMeshes()
|
||||
for _, mesh := range meshes {
|
||||
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
|
||||
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
|
||||
|
||||
@ -395,7 +407,16 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
|
||||
}
|
||||
|
||||
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
|
||||
return s.Meshes
|
||||
meshes := make(map[string]MeshProvider)
|
||||
|
||||
s.lock.RLock()
|
||||
|
||||
for id, mesh := range s.Meshes {
|
||||
meshes[id] = mesh
|
||||
}
|
||||
|
||||
s.lock.RUnlock()
|
||||
return meshes
|
||||
}
|
||||
|
||||
// Close the mesh manager
|
||||
@ -432,6 +453,7 @@ type NewMeshManagerParams struct {
|
||||
InterfaceManipulator wg.WgInterfaceManipulator
|
||||
ConfigApplyer MeshConfigApplyer
|
||||
RouteManager RouteManager
|
||||
OnDelete func(MeshProvider)
|
||||
}
|
||||
|
||||
// Creates a new instance of a mesh manager with the given parameters
|
||||
@ -466,5 +488,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
|
||||
aliasManager := NewAliasManager()
|
||||
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
|
||||
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
|
||||
m.OnDelete = params.OnDelete
|
||||
return m
|
||||
}
|
||||
|
@ -81,6 +81,11 @@ type MeshProviderStub struct {
|
||||
snapshot *MeshSnapshotStub
|
||||
}
|
||||
|
||||
// RemoveNode implements MeshProvider.
|
||||
func (*MeshProviderStub) RemoveNode(nodeId string) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@ -287,9 +292,9 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
|
||||
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
|
||||
}
|
||||
|
||||
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
|
||||
func (m *MeshManagerStub) GetPublicKey() *wgtypes.Key {
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
return &key, nil
|
||||
return &key
|
||||
}
|
||||
|
||||
func (m *MeshManagerStub) AddSelf(params *AddSelfParams) error {
|
||||
|
@ -138,6 +138,8 @@ type MeshProvider interface {
|
||||
GetPeers() []string
|
||||
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
|
||||
GetRoutes(targetNode string) (map[string]Route, error)
|
||||
// RemoveNode(): remove the node from the mesh
|
||||
RemoveNode(nodeId string) error
|
||||
}
|
||||
|
||||
// HostParameters contains the IDs of a node
|
||||
|
@ -36,28 +36,17 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
|
||||
logging.Log.WriteInfof("UPDATING WG CONF")
|
||||
|
||||
if s.manager.HasChanges(meshId) {
|
||||
err := s.manager.ApplyConfig()
|
||||
s.manager.GetRouteManager().UpdateRoutes()
|
||||
err := s.manager.ApplyConfig()
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||
}
|
||||
|
||||
publicKey := s.manager.GetPublicKey()
|
||||
|
||||
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
||||
self, err := s.manager.GetSelf(meshId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
selfPublickey, err := self.GetPublicKey()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
|
||||
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
||||
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
||||
|
||||
for _, node := range randomSubset {
|
||||
@ -68,7 +57,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
|
||||
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
|
||||
logging.Log.WriteInfof("Sending to random cluster")
|
||||
interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String())
|
||||
interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
||||
randomSubset = append(randomSubset, interCluster)
|
||||
}
|
||||
|
||||
@ -102,6 +91,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
// Check if any changes have occurred and trigger callbacks
|
||||
// if changes have occurred.
|
||||
// return s.manager.GetMonitor().Trigger()
|
||||
s.manager.GetMesh(meshId).SaveChanges()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,6 @@ func syncFunction(syncer Syncer) lib.TimerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
|
||||
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
|
||||
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
|
||||
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
|
||||
}
|
||||
|
@ -12,11 +12,3 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
|
||||
|
||||
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
|
||||
}
|
||||
|
||||
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
|
||||
timerFunc := func() error {
|
||||
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
|
||||
}
|
||||
|
||||
return *lib.NewTimer(timerFunc, 10)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user