Merge pull request #52 from tim-beatham/51-bufix-not-removing-when-withdrawn

51-bugfix-routes-not-removing-when-withdrawn
This commit is contained in:
Tim Beatham 2023-12-10 15:13:57 +00:00 committed by GitHub
commit 4a8a39601f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 200 additions and 111 deletions

View File

@ -449,7 +449,7 @@ func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
@ -467,7 +467,7 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
}
for _, route := range routes {
err = routeMap.Map().Delete(route)
err = routeMap.Map().Delete(route.GetDestination().String())
}
return err

View File

@ -320,7 +320,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
}
// DeleteRoutes: deletes the routes from the node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
}
@ -331,8 +331,15 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string)
node := m.store.Get(nodeId)
changes := false
for _, route := range routes {
delete(node.Routes, route)
changes = true
delete(node.Routes, route.GetDestination().String())
}
if changes {
m.store.Put(nodeId, node)
}
return nil

View File

@ -98,7 +98,12 @@ func (m *VectorClock[K]) Prune() {
}
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
return m.vectors[m.hashFunc(m.processID)].lastUpdate
m.lock.RLock()
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
m.lock.RUnlock()
return lastUpdate
}
func (m *VectorClock[K]) Put(key K, value uint64) {

View File

@ -7,6 +7,27 @@ func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{})
}
type MapItemsEntry[K cmp.Ordered, V any] struct {
Key K
Value V
}
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
keys := MapKeys(m)
values := MapValues(m)
vs := make([]MapItemsEntry[K, V], len(keys))
for index, _ := range keys {
vs[index] = MapItemsEntry[K, V]{
Key: keys[index],
Value: values[index],
}
}
return vs
}
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
values := make([]V, len(m)-len(exclude))

View File

@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
routes, err := c.listRoutes(ifName, family)
if err != nil {
return fmt.Errorf("failed to add route %w", err)
return err
}
// If it already exists no need to add the route
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
prevRoute.Attributes.Gateway.Equal(route.Gateway)
}) {
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
}
return nil

View File

@ -10,7 +10,6 @@ import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -35,7 +34,8 @@ type routeNode struct {
route Route
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
@ -66,7 +66,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
}
// Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, m.hashFunc)
pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc)
}
if pickedRoute.gateway == pubKey.String() {
@ -169,8 +169,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
routes[destination] = append(otherRoute, rn)
}
}
@ -185,6 +183,22 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh
return peer
}
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
return p1.PublicKey.String() == p2.PublicKey.String()
})
})
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
return wgtypes.PeerConfig{
PublicKey: p.PublicKey,
Remove: true,
}
})
}
type GetConfigParams struct {
mesh MeshProvider
peers []MeshNode
@ -198,11 +212,16 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routes := lib.Map(lib.MapKeys(params.routes), func(destination string) net.IPNet {
_, ipNet, _ := net.ParseCIDR(destination)
return *ipNet
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway)
return meshNet.Contains(ip)
})
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination()
})
routes = append(routes, *meshNet)
if err != nil {
@ -210,9 +229,7 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
}
peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey()
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
@ -291,7 +308,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes)
cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
if err != nil {
return nil, err
@ -308,7 +325,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
continue
}
peer, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes)
peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
if err != nil {
return nil, err
@ -319,15 +336,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
}
cfg := wgtypes.Config{
Peers: peerConfigs,
ReplacePeers: true,
Peers: peerConfigs,
}
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
return &cfg, err
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh()
if err != nil {
@ -357,7 +373,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
var cfg *wgtypes.Config = nil
routes := m.getRoutes(mesh)
configParams := &GetConfigParams{
mesh: mesh,
peers: peers,
@ -377,6 +392,9 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
cfg.Peers = append(cfg.Peers, toRemove...)
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
if err != nil {
@ -386,9 +404,36 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return nil
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh)
routes := m.getRoutes(mesh)
for destination, route := range routes {
_, ok := allRoutes[destination]
if !ok {
allRoutes[destination] = route
continue
}
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
allRoutes[destination] = append(allRoutes[destination], route...)
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
allRoutes[destination] = route
}
}
}
return allRoutes
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes := m.getAllRoutes()
for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes)
if err != nil {
return err

View File

@ -276,7 +276,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
return nil
}
// LeaveMesh leaves the mesh network
@ -287,10 +287,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return fmt.Errorf("mesh %s does not exist", meshId)
}
var err error
s.RouteManager.RemoveRoutes(meshId)
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil {
return err

View File

@ -6,12 +6,10 @@ import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type RouteManager interface {
UpdateRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
@ -21,7 +19,7 @@ type RouteManagerImpl struct {
func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes()
ulaBuilder := new(ip.ULABuilder)
routes := make(map[string][]Route)
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
@ -30,13 +28,11 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
if _, ok := routes[mesh1.GetMeshId()]; !ok {
routes[mesh1.GetMeshId()] = make([]Route, 0)
}
routeMap, err := mesh1.GetRoutes(pubKey.String())
routeMap, err := mesh1.GetRoutes(NodeID(self))
if err != nil {
return err
@ -54,57 +50,62 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
}
for _, mesh2 := range meshes {
routeValues, ok := routes[mesh2.GetMeshId()]
if !ok {
routeValues = make([]Route, 0)
}
if mesh1 == mesh2 {
continue
}
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId())
mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
if err != nil {
logging.Log.WriteErrorf(err.Error())
return err
}
routes := lib.MapValues(routeMap)
err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{
Destination: ipNet,
routeValues = append(routeValues, &RouteStub{
Destination: mesh1IpNet,
HopCount: 0,
Path: make([]string, 0),
})...)
Path: []string{mesh1.GetMeshId()},
})
if err != nil {
return err
routeValues = append(routeValues, lib.MapValues(routeMap)...)
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
routeValues = lib.Filter(routeValues, func(r Route) bool {
pathNotMesh := func(s string) bool {
return s == mesh2.GetMeshId()
}
// Ensure that the route does not see it's own IP
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh)
})
routes[mesh2.GetMeshId()] = routeValues
}
}
// Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId)
self, _ := r.meshManager.GetSelf(meshId)
toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route)
}) {
toRemove = append(toRemove, route)
}
}
mesh.RemoveRoutes(NodeID(self), toRemove...)
mesh.AddRoutes(NodeID(self), meshRoutes...)
}
return nil
}
// removeRoutes: removes all meshes we are no longer a part of
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
if err != nil {
return err
}
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
}
return nil
}
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
return &RouteManagerImpl{meshManager: m, conf: conf}
}

View File

@ -126,7 +126,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
return nil
}

View File

@ -4,6 +4,7 @@ package mesh
import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
@ -19,6 +20,12 @@ type Route interface {
GetPath() []string
}
func RouteEquals(r1, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath())
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
@ -71,11 +78,6 @@ func NodeEquals(node1, node2 MeshNode) bool {
return key1.String() == key2.String()
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
@ -116,7 +118,7 @@ type MeshProvider interface {
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
RemoveRoutes(nodeId string, route ...Route) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh

View File

@ -19,11 +19,7 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
return err
}
ip6Routes := lib.Filter(routes, func(r lib.Route) bool {
return r.Destination.IP.To4() == nil
})
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...)
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err

View File

@ -30,15 +30,12 @@ type SyncerImpl struct {
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
// Self can be nil if the node is removed
self, _ := s.manager.GetSelf(meshId)
s.manager.GetMesh(meshId).Prune()
if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
@ -52,10 +49,16 @@ func (s *SyncerImpl) Sync(meshId string) error {
nodeNames := s.manager.GetMesh(meshId).GetPeers()
if self != nil {
nodeNames = lib.Filter(nodeNames, func(s string) bool {
return s != mesh.NodeID(self)
})
}
var gossipNodes []string
// Clients always pings its peer for configuration
if self.GetType() == conf.CLIENT_ROLE {
if self != nil && self.GetType() == conf.CLIENT_ROLE {
keyFunc := lib.HashString
bucketFunc := lib.HashString
@ -108,7 +111,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
s.lastSync = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err = s.manager.ApplyConfig()
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)