diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 0a0cc14..1e5b580 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -449,7 +449,7 @@ func (m *CrdtMeshManager) RemoveNode(nodeId string) error { } // DeleteRoutes deletes the specified routes -func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { +func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error { nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil { @@ -467,7 +467,7 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { } for _, route := range routes { - err = routeMap.Map().Delete(route) + err = routeMap.Map().Delete(route.GetDestination().String()) } return err diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go index b3b94ca..bd77585 100644 --- a/pkg/crdt/datastore.go +++ b/pkg/crdt/datastore.go @@ -320,7 +320,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route } // DeleteRoutes: deletes the routes from the node -func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error { +func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error { if !m.store.Contains(nodeId) { return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) } @@ -331,8 +331,15 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) node := m.store.Get(nodeId) + changes := false + for _, route := range routes { - delete(node.Routes, route) + changes = true + delete(node.Routes, route.GetDestination().String()) + } + + if changes { + m.store.Put(nodeId, node) } return nil diff --git a/pkg/crdt/vector_clock.go b/pkg/crdt/vector_clock.go index 78882c0..584cef3 100644 --- a/pkg/crdt/vector_clock.go +++ b/pkg/crdt/vector_clock.go @@ -98,7 +98,12 @@ func (m *VectorClock[K]) Prune() { } func (m *VectorClock[K]) GetTimestamp(processId K) uint64 { - return m.vectors[m.hashFunc(m.processID)].lastUpdate + m.lock.RLock() + + lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate + + m.lock.RUnlock() + return lastUpdate } func (m *VectorClock[K]) Put(key K, value uint64) { diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go index 053ef33..70d3a9a 100644 --- a/pkg/lib/conv.go +++ b/pkg/lib/conv.go @@ -7,6 +7,27 @@ func MapValues[K cmp.Ordered, V any](m map[K]V) []V { return MapValuesWithExclude(m, map[K]struct{}{}) } +type MapItemsEntry[K cmp.Ordered, V any] struct { + Key K + Value V +} + +func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] { + keys := MapKeys(m) + values := MapValues(m) + + vs := make([]MapItemsEntry[K, V], len(keys)) + + for index, _ := range keys { + vs[index] = MapItemsEntry[K, V]{ + Key: keys[index], + Value: values[index], + } + } + + return vs +} + func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V { values := make([]V, len(m)-len(exclude)) diff --git a/pkg/lib/rtnetlink.go b/pkg/lib/rtnetlink.go index f95b54b..3daa5ef 100644 --- a/pkg/lib/rtnetlink.go +++ b/pkg/lib/rtnetlink.go @@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error { family = unix.AF_INET } - attr := rtnetlink.RouteAttributes{ - Dst: dst.IP, - OutIface: uint32(iface.Index), - Gateway: gw, - } - - ones, _ := dst.Mask.Size() - - err = c.conn.Route.Replace(&rtnetlink.RouteMessage{ - Family: family, - Table: unix.RT_TABLE_MAIN, - Protocol: unix.RTPROT_BOOT, - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - DstLength: uint8(ones), - Attributes: attr, - }) + routes, err := c.listRoutes(ifName, family) if err != nil { - return fmt.Errorf("failed to add route %w", err) + return err + } + + // If it already exists no need to add the route + if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool { + return prevRoute.Attributes.Dst.Equal(route.Destination.IP) && + prevRoute.Attributes.Gateway.Equal(route.Gateway) + }) { + attr := rtnetlink.RouteAttributes{ + Dst: dst.IP, + OutIface: uint32(iface.Index), + Gateway: gw, + } + + ones, _ := dst.Mask.Size() + + err = c.conn.Route.Replace(&rtnetlink.RouteMessage{ + Family: family, + Table: unix.RT_TABLE_MAIN, + Protocol: unix.RTPROT_BOOT, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + DstLength: uint8(ones), + Attributes: attr, + }) + + if err != nil { + return fmt.Errorf("failed to add route %w", err) + } } return nil diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 4cf04c6..bddfd3e 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -10,7 +10,6 @@ import ( "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" - logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/route" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -35,7 +34,8 @@ type routeNode struct { route Route } -func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device, +func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode, + device *wgtypes.Device, peerToClients map[string][]net.IPNet, routes map[string][]routeNode) (*wgtypes.PeerConfig, error) { @@ -66,7 +66,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev } // Else there is more than one candidate so consistently hash - pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, m.hashFunc) + pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc) } if pickedRoute.gateway == pubKey.String() { @@ -169,8 +169,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] } else if route.GetHopCount() < otherRoute[0].route.GetHopCount() { otherRoute[0] = rn } else if otherRoute[0].route.GetHopCount() == route.GetHopCount() { - logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount()) - logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount()) routes[destination] = append(otherRoute, rn) } } @@ -185,6 +183,22 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh return peer } +func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig { + peers := dev.Peers + peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool { + return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool { + return p1.PublicKey.String() == p2.PublicKey.String() + }) + }) + + return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig { + return wgtypes.PeerConfig{ + PublicKey: p.PublicKey, + Remove: true, + } + }) +} + type GetConfigParams struct { mesh MeshProvider peers []MeshNode @@ -198,11 +212,16 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes ula := &ip.ULABuilder{} meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId()) - routes := lib.Map(lib.MapKeys(params.routes), func(destination string) net.IPNet { - _, ipNet, _ := net.ParseCIDR(destination) - return *ipNet + routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode { + return lib.Filter(rns, func(rn routeNode) bool { + ip, _, _ := net.ParseCIDR(rn.gateway) + return meshNet.Contains(ip) + }) }) + routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet { + return *rs[0].route.GetDestination() + }) routes = append(routes, *meshNet) if err != nil { @@ -210,9 +229,7 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes } peer := m.getCorrespondingPeer(params.peers, self) - pubKey, _ := peer.GetPublicKey() - keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint()) @@ -291,7 +308,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C peerToClients[pubKey.String()] = append(clients, *n.GetWgHost()) if NodeEquals(self, peer) { - cfg, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes) + cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes) if err != nil { return nil, err @@ -308,7 +325,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C continue } - peer, err := m.convertMeshNode(n, params.dev, peerToClients, params.routes) + peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes) if err != nil { return nil, err @@ -319,15 +336,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C } cfg := wgtypes.Config{ - Peers: peerConfigs, - ReplacePeers: true, + Peers: peerConfigs, } err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...) return &cfg, err } -func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { +func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error { snap, err := mesh.GetMesh() if err != nil { @@ -357,7 +373,6 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { var cfg *wgtypes.Config = nil - routes := m.getRoutes(mesh) configParams := &GetConfigParams{ mesh: mesh, peers: peers, @@ -377,6 +392,9 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return err } + toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers) + cfg.Peers = append(cfg.Peers, toRemove...) + err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg) if err != nil { @@ -386,9 +404,36 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return nil } -func (m *WgMeshConfigApplyer) ApplyConfig() error { +func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode { + allRoutes := make(map[string][]routeNode) + for _, mesh := range m.meshManager.GetMeshes() { - err := m.updateWgConf(mesh) + routes := m.getRoutes(mesh) + + for destination, route := range routes { + _, ok := allRoutes[destination] + + if !ok { + allRoutes[destination] = route + continue + } + + if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() { + allRoutes[destination] = append(allRoutes[destination], route...) + } else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() { + allRoutes[destination] = route + } + } + } + + return allRoutes +} + +func (m *WgMeshConfigApplyer) ApplyConfig() error { + allRoutes := m.getAllRoutes() + + for _, mesh := range m.meshManager.GetMeshes() { + err := m.updateWgConf(mesh, allRoutes) if err != nil { return err diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 60a142d..8fe4d33 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -276,7 +276,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { } s.Meshes[params.MeshId].AddNode(node) - return s.RouteManager.UpdateRoutes() + return nil } // LeaveMesh leaves the mesh network @@ -287,10 +287,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { return fmt.Errorf("mesh %s does not exist", meshId) } - var err error - - s.RouteManager.RemoveRoutes(meshId) - err = mesh.RemoveNode(s.HostParameters.GetPublicKey()) + err := mesh.RemoveNode(s.HostParameters.GetPublicKey()) if err != nil { return err diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 2354367..70ac341 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -6,12 +6,10 @@ import ( "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" - logging "github.com/tim-beatham/wgmesh/pkg/log" ) type RouteManager interface { UpdateRoutes() error - RemoveRoutes(meshId string) error } type RouteManagerImpl struct { @@ -21,7 +19,7 @@ type RouteManagerImpl struct { func (r *RouteManagerImpl) UpdateRoutes() error { meshes := r.meshManager.GetMeshes() - ulaBuilder := new(ip.ULABuilder) + routes := make(map[string][]Route) for _, mesh1 := range meshes { self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) @@ -30,13 +28,11 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - pubKey, err := self.GetPublicKey() - - if err != nil { - return err + if _, ok := routes[mesh1.GetMeshId()]; !ok { + routes[mesh1.GetMeshId()] = make([]Route, 0) } - routeMap, err := mesh1.GetRoutes(pubKey.String()) + routeMap, err := mesh1.GetRoutes(NodeID(self)) if err != nil { return err @@ -54,57 +50,62 @@ func (r *RouteManagerImpl) UpdateRoutes() error { } for _, mesh2 := range meshes { + routeValues, ok := routes[mesh2.GetMeshId()] + + if !ok { + routeValues = make([]Route, 0) + } + if mesh1 == mesh2 { continue } - ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId()) + mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId()) - if err != nil { - logging.Log.WriteErrorf(err.Error()) - return err - } - - routes := lib.MapValues(routeMap) - - err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{ - Destination: ipNet, + routeValues = append(routeValues, &RouteStub{ + Destination: mesh1IpNet, HopCount: 0, - Path: make([]string, 0), - })...) + Path: []string{mesh1.GetMeshId()}, + }) - if err != nil { - return err + routeValues = append(routeValues, lib.MapValues(routeMap)...) + mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId()) + routeValues = lib.Filter(routeValues, func(r Route) bool { + pathNotMesh := func(s string) bool { + return s == mesh2.GetMeshId() + } + + // Ensure that the route does not see it's own IP + return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh) + }) + + routes[mesh2.GetMeshId()] = routeValues + } + } + + // Calculate the set different of each, working out routes to remove and to keep. + for meshId, meshRoutes := range routes { + mesh := r.meshManager.GetMesh(meshId) + self, _ := r.meshManager.GetSelf(meshId) + toRemove := make([]Route, 0) + + prevRoutes, _ := mesh.GetRoutes(NodeID(self)) + + for _, route := range prevRoutes { + if !lib.Contains(meshRoutes, func(r Route) bool { + return RouteEquals(r, route) + }) { + toRemove = append(toRemove, route) } } + + mesh.RemoveRoutes(NodeID(self), toRemove...) + mesh.AddRoutes(NodeID(self), meshRoutes...) } return nil } -// removeRoutes: removes all meshes we are no longer a part of -func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { - ulaBuilder := new(ip.ULABuilder) - meshes := r.meshManager.GetMeshes() - - ipNet, err := ulaBuilder.GetIPNet(meshId) - - if err != nil { - return err - } - - for _, mesh1 := range meshes { - self, err := r.meshManager.GetSelf(meshId) - - if err != nil { - return err - } - - mesh1.RemoveRoutes(NodeID(self), ipNet.String()) - } - return nil -} - func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager { return &RouteManagerImpl{meshManager: m, conf: conf} } diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 0cf622c..96811c5 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -126,7 +126,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error { } // RemoveRoutes implements MeshProvider. -func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { +func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error { return nil } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 72556d5..16b9d9c 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -4,6 +4,7 @@ package mesh import ( "net" + "slices" "github.com/tim-beatham/wgmesh/pkg/conf" "golang.zx2c4.com/wireguard/wgctrl" @@ -19,6 +20,12 @@ type Route interface { GetPath() []string } +func RouteEquals(r1, r2 Route) bool { + return r1.GetDestination().String() == r2.GetDestination().String() && + r1.GetHopCount() == r2.GetHopCount() && + slices.Equal(r1.GetPath(), r2.GetPath()) +} + type RouteStub struct { Destination *net.IPNet HopCount int @@ -71,11 +78,6 @@ func NodeEquals(node1, node2 MeshNode) bool { return key1.String() == key2.String() } -func RouteEquals(route1, route2 Route) bool { - return route1.GetDestination().String() == route2.GetDestination().String() && - route1.GetHopCount() == route2.GetHopCount() -} - func NodeID(node MeshNode) string { key, _ := node.GetPublicKey() return key.String() @@ -116,7 +118,7 @@ type MeshProvider interface { // AddRoutes: adds routes to the given node AddRoutes(nodeId string, route ...Route) error // DeleteRoutes: deletes the routes from the node - RemoveRoutes(nodeId string, route ...string) error + RemoveRoutes(nodeId string, route ...Route) error // GetSyncer: returns the automerge syncer for sync GetSyncer() MeshSyncer // GetNode get a particular not within the mesh diff --git a/pkg/route/route.go b/pkg/route/route.go index 976b6c4..11de7d7 100644 --- a/pkg/route/route.go +++ b/pkg/route/route.go @@ -19,11 +19,7 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) return err } - ip6Routes := lib.Filter(routes, func(r lib.Route) bool { - return r.Destination.IP.To4() == nil - }) - - err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...) + err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...) if err != nil { return err diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 95efe39..d6e3db8 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -30,15 +30,12 @@ type SyncerImpl struct { // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { - self, err := s.manager.GetSelf(meshId) - - if err != nil { - return err - } + // Self can be nil if the node is removed + self, _ := s.manager.GetSelf(meshId) s.manager.GetMesh(meshId).Prune() - if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { + if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { logging.Log.WriteInfof("No changes for %s", meshId) return nil } @@ -52,10 +49,16 @@ func (s *SyncerImpl) Sync(meshId string) error { nodeNames := s.manager.GetMesh(meshId).GetPeers() + if self != nil { + nodeNames = lib.Filter(nodeNames, func(s string) bool { + return s != mesh.NodeID(self) + }) + } + var gossipNodes []string // Clients always pings its peer for configuration - if self.GetType() == conf.CLIENT_ROLE { + if self != nil && self.GetType() == conf.CLIENT_ROLE { keyFunc := lib.HashString bucketFunc := lib.HashString @@ -108,7 +111,7 @@ func (s *SyncerImpl) Sync(meshId string) error { s.lastSync = uint64(time.Now().Unix()) logging.Log.WriteInfof("UPDATING WG CONF") - err = s.manager.ApplyConfig() + err := s.manager.ApplyConfig() if err != nil { logging.Log.WriteInfof("Failed to update config %w", err)