mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-01-08 14:29:00 +01:00
Merge pull request #52 from tim-beatham/51-bufix-not-removing-when-withdrawn
51-bugfix-routes-not-removing-when-withdrawn
This commit is contained in:
commit
4a8a39601f
@ -449,7 +449,7 @@ func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes deletes the specified routes
|
// DeleteRoutes deletes the specified routes
|
||||||
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -467,7 +467,7 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
err = routeMap.Map().Delete(route)
|
err = routeMap.Map().Delete(route.GetDestination().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
@ -320,7 +320,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes: deletes the routes from the node
|
// DeleteRoutes: deletes the routes from the node
|
||||||
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
if !m.store.Contains(nodeId) {
|
if !m.store.Contains(nodeId) {
|
||||||
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
}
|
}
|
||||||
@ -331,8 +331,15 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string)
|
|||||||
|
|
||||||
node := m.store.Get(nodeId)
|
node := m.store.Get(nodeId)
|
||||||
|
|
||||||
|
changes := false
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
delete(node.Routes, route)
|
changes = true
|
||||||
|
delete(node.Routes, route.GetDestination().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if changes {
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -98,7 +98,12 @@ func (m *VectorClock[K]) Prune() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
||||||
return m.vectors[m.hashFunc(m.processID)].lastUpdate
|
m.lock.RLock()
|
||||||
|
|
||||||
|
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return lastUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *VectorClock[K]) Put(key K, value uint64) {
|
func (m *VectorClock[K]) Put(key K, value uint64) {
|
||||||
|
@ -7,6 +7,27 @@ func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
|
|||||||
return MapValuesWithExclude(m, map[K]struct{}{})
|
return MapValuesWithExclude(m, map[K]struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MapItemsEntry[K cmp.Ordered, V any] struct {
|
||||||
|
Key K
|
||||||
|
Value V
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
|
||||||
|
keys := MapKeys(m)
|
||||||
|
values := MapValues(m)
|
||||||
|
|
||||||
|
vs := make([]MapItemsEntry[K, V], len(keys))
|
||||||
|
|
||||||
|
for index, _ := range keys {
|
||||||
|
vs[index] = MapItemsEntry[K, V]{
|
||||||
|
Key: keys[index],
|
||||||
|
Value: values[index],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
|
||||||
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
|
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
|
||||||
values := make([]V, len(m)-len(exclude))
|
values := make([]V, len(m)-len(exclude))
|
||||||
|
|
||||||
|
@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
|
|||||||
family = unix.AF_INET
|
family = unix.AF_INET
|
||||||
}
|
}
|
||||||
|
|
||||||
attr := rtnetlink.RouteAttributes{
|
routes, err := c.listRoutes(ifName, family)
|
||||||
Dst: dst.IP,
|
|
||||||
OutIface: uint32(iface.Index),
|
|
||||||
Gateway: gw,
|
|
||||||
}
|
|
||||||
|
|
||||||
ones, _ := dst.Mask.Size()
|
|
||||||
|
|
||||||
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
|
|
||||||
Family: family,
|
|
||||||
Table: unix.RT_TABLE_MAIN,
|
|
||||||
Protocol: unix.RTPROT_BOOT,
|
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
|
||||||
Type: unix.RTN_UNICAST,
|
|
||||||
DstLength: uint8(ones),
|
|
||||||
Attributes: attr,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add route %w", err)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it already exists no need to add the route
|
||||||
|
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
|
||||||
|
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
|
||||||
|
prevRoute.Attributes.Gateway.Equal(route.Gateway)
|
||||||
|
}) {
|
||||||
|
attr := rtnetlink.RouteAttributes{
|
||||||
|
Dst: dst.IP,
|
||||||
|
OutIface: uint32(iface.Index),
|
||||||
|
Gateway: gw,
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := dst.Mask.Size()
|
||||||
|
|
||||||
|
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
|
||||||
|
Family: family,
|
||||||
|
Table: unix.RT_TABLE_MAIN,
|
||||||
|
Protocol: unix.RTPROT_BOOT,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
DstLength: uint8(ones),
|
||||||
|
Attributes: attr,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add route %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/route"
|
"github.com/tim-beatham/wgmesh/pkg/route"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@ -35,7 +34,8 @@ type routeNode struct {
|
|||||||
route Route
|
route Route
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
|
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
|
||||||
|
device *wgtypes.Device,
|
||||||
peerToClients map[string][]net.IPNet,
|
peerToClients map[string][]net.IPNet,
|
||||||
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
|
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Else there is more than one candidate so consistently hash
|
// Else there is more than one candidate so consistently hash
|
||||||
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, m.hashFunc)
|
pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pickedRoute.gateway == pubKey.String() {
|
if pickedRoute.gateway == pubKey.String() {
|
||||||
@ -169,8 +169,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
||||||
otherRoute[0] = rn
|
otherRoute[0] = rn
|
||||||
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
||||||
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
|
|
||||||
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
|
|
||||||
routes[destination] = append(otherRoute, rn)
|
routes[destination] = append(otherRoute, rn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -185,6 +183,22 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh
|
|||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
|
||||||
|
peers := dev.Peers
|
||||||
|
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
|
||||||
|
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
|
||||||
|
return p1.PublicKey.String() == p2.PublicKey.String()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
|
||||||
|
return wgtypes.PeerConfig{
|
||||||
|
PublicKey: p.PublicKey,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type GetConfigParams struct {
|
type GetConfigParams struct {
|
||||||
mesh MeshProvider
|
mesh MeshProvider
|
||||||
peers []MeshNode
|
peers []MeshNode
|
||||||
@ -198,11 +212,16 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
|
|||||||
ula := &ip.ULABuilder{}
|
ula := &ip.ULABuilder{}
|
||||||
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
|
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
|
||||||
|
|
||||||
routes := lib.Map(lib.MapKeys(params.routes), func(destination string) net.IPNet {
|
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
|
||||||
_, ipNet, _ := net.ParseCIDR(destination)
|
return lib.Filter(rns, func(rn routeNode) bool {
|
||||||
return *ipNet
|
ip, _, _ := net.ParseCIDR(rn.gateway)
|
||||||
|
return meshNet.Contains(ip)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
|
||||||
|
return *rs[0].route.GetDestination()
|
||||||
|
})
|
||||||
routes = append(routes, *meshNet)
|
routes = append(routes, *meshNet)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -210,9 +229,7 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
|
|||||||
}
|
}
|
||||||
|
|
||||||
peer := m.getCorrespondingPeer(params.peers, self)
|
peer := m.getCorrespondingPeer(params.peers, self)
|
||||||
|
|
||||||
pubKey, _ := peer.GetPublicKey()
|
pubKey, _ := peer.GetPublicKey()
|
||||||
|
|
||||||
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
||||||
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
||||||
|
|
||||||
@ -291,7 +308,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
|
|||||||
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
||||||
|
|
||||||
if NodeEquals(self, peer) {
|
if NodeEquals(self, peer) {
|
||||||
cfg, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes)
|
cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -308,7 +325,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes)
|
peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -319,15 +336,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := wgtypes.Config{
|
cfg := wgtypes.Config{
|
||||||
Peers: peerConfigs,
|
Peers: peerConfigs,
|
||||||
ReplacePeers: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
|
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
|
||||||
return &cfg, err
|
return &cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
|
||||||
snap, err := mesh.GetMesh()
|
snap, err := mesh.GetMesh()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -357,7 +373,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
|
|
||||||
var cfg *wgtypes.Config = nil
|
var cfg *wgtypes.Config = nil
|
||||||
|
|
||||||
routes := m.getRoutes(mesh)
|
|
||||||
configParams := &GetConfigParams{
|
configParams := &GetConfigParams{
|
||||||
mesh: mesh,
|
mesh: mesh,
|
||||||
peers: peers,
|
peers: peers,
|
||||||
@ -377,6 +392,9 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
|
||||||
|
cfg.Peers = append(cfg.Peers, toRemove...)
|
||||||
|
|
||||||
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -386,9 +404,36 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
|
||||||
|
allRoutes := make(map[string][]routeNode)
|
||||||
|
|
||||||
for _, mesh := range m.meshManager.GetMeshes() {
|
for _, mesh := range m.meshManager.GetMeshes() {
|
||||||
err := m.updateWgConf(mesh)
|
routes := m.getRoutes(mesh)
|
||||||
|
|
||||||
|
for destination, route := range routes {
|
||||||
|
_, ok := allRoutes[destination]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
allRoutes[destination] = route
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
|
||||||
|
allRoutes[destination] = append(allRoutes[destination], route...)
|
||||||
|
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
|
||||||
|
allRoutes[destination] = route
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||||
|
allRoutes := m.getAllRoutes()
|
||||||
|
|
||||||
|
for _, mesh := range m.meshManager.GetMeshes() {
|
||||||
|
err := m.updateWgConf(mesh, allRoutes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -276,7 +276,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.Meshes[params.MeshId].AddNode(node)
|
s.Meshes[params.MeshId].AddNode(node)
|
||||||
return s.RouteManager.UpdateRoutes()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LeaveMesh leaves the mesh network
|
// LeaveMesh leaves the mesh network
|
||||||
@ -287,10 +287,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
|||||||
return fmt.Errorf("mesh %s does not exist", meshId)
|
return fmt.Errorf("mesh %s does not exist", meshId)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
||||||
|
|
||||||
s.RouteManager.RemoveRoutes(meshId)
|
|
||||||
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -6,12 +6,10 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouteManager interface {
|
type RouteManager interface {
|
||||||
UpdateRoutes() error
|
UpdateRoutes() error
|
||||||
RemoveRoutes(meshId string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteManagerImpl struct {
|
type RouteManagerImpl struct {
|
||||||
@ -21,7 +19,7 @@ type RouteManagerImpl struct {
|
|||||||
|
|
||||||
func (r *RouteManagerImpl) UpdateRoutes() error {
|
func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||||
meshes := r.meshManager.GetMeshes()
|
meshes := r.meshManager.GetMeshes()
|
||||||
ulaBuilder := new(ip.ULABuilder)
|
routes := make(map[string][]Route)
|
||||||
|
|
||||||
for _, mesh1 := range meshes {
|
for _, mesh1 := range meshes {
|
||||||
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
||||||
@ -30,13 +28,11 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := self.GetPublicKey()
|
if _, ok := routes[mesh1.GetMeshId()]; !ok {
|
||||||
|
routes[mesh1.GetMeshId()] = make([]Route, 0)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
routeMap, err := mesh1.GetRoutes(pubKey.String())
|
routeMap, err := mesh1.GetRoutes(NodeID(self))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -54,57 +50,62 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, mesh2 := range meshes {
|
for _, mesh2 := range meshes {
|
||||||
|
routeValues, ok := routes[mesh2.GetMeshId()]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
routeValues = make([]Route, 0)
|
||||||
|
}
|
||||||
|
|
||||||
if mesh1 == mesh2 {
|
if mesh1 == mesh2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId())
|
mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
|
||||||
|
|
||||||
if err != nil {
|
routeValues = append(routeValues, &RouteStub{
|
||||||
logging.Log.WriteErrorf(err.Error())
|
Destination: mesh1IpNet,
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
routes := lib.MapValues(routeMap)
|
|
||||||
|
|
||||||
err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{
|
|
||||||
Destination: ipNet,
|
|
||||||
HopCount: 0,
|
HopCount: 0,
|
||||||
Path: make([]string, 0),
|
Path: []string{mesh1.GetMeshId()},
|
||||||
})...)
|
})
|
||||||
|
|
||||||
if err != nil {
|
routeValues = append(routeValues, lib.MapValues(routeMap)...)
|
||||||
return err
|
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
|
||||||
|
routeValues = lib.Filter(routeValues, func(r Route) bool {
|
||||||
|
pathNotMesh := func(s string) bool {
|
||||||
|
return s == mesh2.GetMeshId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the route does not see it's own IP
|
||||||
|
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh)
|
||||||
|
})
|
||||||
|
|
||||||
|
routes[mesh2.GetMeshId()] = routeValues
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the set different of each, working out routes to remove and to keep.
|
||||||
|
for meshId, meshRoutes := range routes {
|
||||||
|
mesh := r.meshManager.GetMesh(meshId)
|
||||||
|
self, _ := r.meshManager.GetSelf(meshId)
|
||||||
|
toRemove := make([]Route, 0)
|
||||||
|
|
||||||
|
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
|
||||||
|
|
||||||
|
for _, route := range prevRoutes {
|
||||||
|
if !lib.Contains(meshRoutes, func(r Route) bool {
|
||||||
|
return RouteEquals(r, route)
|
||||||
|
}) {
|
||||||
|
toRemove = append(toRemove, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mesh.RemoveRoutes(NodeID(self), toRemove...)
|
||||||
|
mesh.AddRoutes(NodeID(self), meshRoutes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeRoutes: removes all meshes we are no longer a part of
|
|
||||||
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
|
|
||||||
ulaBuilder := new(ip.ULABuilder)
|
|
||||||
meshes := r.meshManager.GetMeshes()
|
|
||||||
|
|
||||||
ipNet, err := ulaBuilder.GetIPNet(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, mesh1 := range meshes {
|
|
||||||
self, err := r.meshManager.GetSelf(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
|
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
|
||||||
return &RouteManagerImpl{meshManager: m, conf: conf}
|
return &RouteManagerImpl{meshManager: m, conf: conf}
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveRoutes implements MeshProvider.
|
// RemoveRoutes implements MeshProvider.
|
||||||
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
|
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ package mesh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
@ -19,6 +20,12 @@ type Route interface {
|
|||||||
GetPath() []string
|
GetPath() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RouteEquals(r1, r2 Route) bool {
|
||||||
|
return r1.GetDestination().String() == r2.GetDestination().String() &&
|
||||||
|
r1.GetHopCount() == r2.GetHopCount() &&
|
||||||
|
slices.Equal(r1.GetPath(), r2.GetPath())
|
||||||
|
}
|
||||||
|
|
||||||
type RouteStub struct {
|
type RouteStub struct {
|
||||||
Destination *net.IPNet
|
Destination *net.IPNet
|
||||||
HopCount int
|
HopCount int
|
||||||
@ -71,11 +78,6 @@ func NodeEquals(node1, node2 MeshNode) bool {
|
|||||||
return key1.String() == key2.String()
|
return key1.String() == key2.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RouteEquals(route1, route2 Route) bool {
|
|
||||||
return route1.GetDestination().String() == route2.GetDestination().String() &&
|
|
||||||
route1.GetHopCount() == route2.GetHopCount()
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodeID(node MeshNode) string {
|
func NodeID(node MeshNode) string {
|
||||||
key, _ := node.GetPublicKey()
|
key, _ := node.GetPublicKey()
|
||||||
return key.String()
|
return key.String()
|
||||||
@ -116,7 +118,7 @@ type MeshProvider interface {
|
|||||||
// AddRoutes: adds routes to the given node
|
// AddRoutes: adds routes to the given node
|
||||||
AddRoutes(nodeId string, route ...Route) error
|
AddRoutes(nodeId string, route ...Route) error
|
||||||
// DeleteRoutes: deletes the routes from the node
|
// DeleteRoutes: deletes the routes from the node
|
||||||
RemoveRoutes(nodeId string, route ...string) error
|
RemoveRoutes(nodeId string, route ...Route) error
|
||||||
// GetSyncer: returns the automerge syncer for sync
|
// GetSyncer: returns the automerge syncer for sync
|
||||||
GetSyncer() MeshSyncer
|
GetSyncer() MeshSyncer
|
||||||
// GetNode get a particular not within the mesh
|
// GetNode get a particular not within the mesh
|
||||||
|
@ -19,11 +19,7 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ip6Routes := lib.Filter(routes, func(r lib.Route) bool {
|
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
|
||||||
return r.Destination.IP.To4() == nil
|
|
||||||
})
|
|
||||||
|
|
||||||
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -30,15 +30,12 @@ type SyncerImpl struct {
|
|||||||
|
|
||||||
// Sync: Sync random nodes
|
// Sync: Sync random nodes
|
||||||
func (s *SyncerImpl) Sync(meshId string) error {
|
func (s *SyncerImpl) Sync(meshId string) error {
|
||||||
self, err := s.manager.GetSelf(meshId)
|
// Self can be nil if the node is removed
|
||||||
|
self, _ := s.manager.GetSelf(meshId)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.manager.GetMesh(meshId).Prune()
|
s.manager.GetMesh(meshId).Prune()
|
||||||
|
|
||||||
if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
||||||
logging.Log.WriteInfof("No changes for %s", meshId)
|
logging.Log.WriteInfof("No changes for %s", meshId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -52,10 +49,16 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
|
|
||||||
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
||||||
|
|
||||||
|
if self != nil {
|
||||||
|
nodeNames = lib.Filter(nodeNames, func(s string) bool {
|
||||||
|
return s != mesh.NodeID(self)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
var gossipNodes []string
|
var gossipNodes []string
|
||||||
|
|
||||||
// Clients always pings its peer for configuration
|
// Clients always pings its peer for configuration
|
||||||
if self.GetType() == conf.CLIENT_ROLE {
|
if self != nil && self.GetType() == conf.CLIENT_ROLE {
|
||||||
keyFunc := lib.HashString
|
keyFunc := lib.HashString
|
||||||
bucketFunc := lib.HashString
|
bucketFunc := lib.HashString
|
||||||
|
|
||||||
@ -108,7 +111,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
s.lastSync = uint64(time.Now().Unix())
|
s.lastSync = uint64(time.Now().Unix())
|
||||||
|
|
||||||
logging.Log.WriteInfof("UPDATING WG CONF")
|
logging.Log.WriteInfof("UPDATING WG CONF")
|
||||||
err = s.manager.ApplyConfig()
|
err := s.manager.ApplyConfig()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||||
|
Loading…
Reference in New Issue
Block a user