diff --git a/pkg/lib/hashing.go b/pkg/lib/hashing.go index 8bb40ab..824362f 100644 --- a/pkg/lib/hashing.go +++ b/pkg/lib/hashing.go @@ -18,7 +18,7 @@ func HashString(value string) int { // ConsistentHash implementation. Traverse the values until we find a key // less than ours. -func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V { +func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V { if len(values) == 0 { panic("values is empty") } @@ -26,7 +26,7 @@ func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V { vs := Map(values, func(v V) consistentHashRecord[V] { return consistentHashRecord[V]{ v, - keyFunc(v), + bucketFunc(v), } }) diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 9424363..d1c4327 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -7,6 +7,7 @@ import ( "time" "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/route" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -26,7 +27,19 @@ type WgMeshConfigApplyer struct { routeInstaller route.RouteInstaller } -func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) { +type routeNode struct { + gateway string + route Route +} + +func (r *routeNode) equals(route2 *routeNode) bool { + return r.gateway == route2.gateway && RouteEquals(r.route, route2.route) +} + +func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device, + peerToClients map[string][]net.IPNet, + routes map[string][]routeNode) (*wgtypes.PeerConfig, error) { + endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) if err != nil { @@ -42,16 +55,36 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev allowedips := make([]net.IPNet, 1) allowedips[0] = *node.GetWgHost() - for _, route := range node.GetRoutes() { - allowedips = append(allowedips, *route.GetDestination()) - } - clients, ok := peerToClients[node.GetWgHost().String()] if ok { allowedips = append(allowedips, clients...) } + for _, route := range node.GetRoutes() { + bestRoutes := routes[route.GetDestination().String()] + + if len(bestRoutes) == 1 { + allowedips = append(allowedips, *route.GetDestination()) + } else if len(bestRoutes) > 1 { + keyFunc := func(mn MeshNode) int { + pubKey, _ := mn.GetPublicKey() + return lib.HashString(pubKey.String()) + } + + bucketFunc := func(rn routeNode) int { + return lib.HashString(rn.gateway) + } + + // Else there is more than one candidate so consistently hash + pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc) + + if pickedRoute.gateway == pubKey.String() { + allowedips = append(allowedips, *route.GetDestination()) + } + } + } + keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool { @@ -73,6 +106,37 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev return &peerConfig, nil } +// getRoutes: finds the routes with the least hop distance. If more than one route exists +// consistently hash to evenly spread the distribution of traffic +func (m *WgMeshConfigApplyer) getRoutes(mesh MeshSnapshot) map[string][]routeNode { + routes := make(map[string][]routeNode) + + for _, node := range mesh.GetNodes() { + for _, route := range node.GetRoutes() { + destination := route.GetDestination().String() + otherRoute, ok := routes[destination] + pubKey, _ := node.GetPublicKey() + + rn := routeNode{ + gateway: pubKey.String(), + route: route, + } + + if !ok { + otherRoute = make([]routeNode, 1) + otherRoute[0] = rn + routes[destination] = otherRoute + } else if otherRoute[0].route.GetHopCount() > route.GetHopCount() { + otherRoute[0] = rn + } else if otherRoute[0].route.GetHopCount() == route.GetHopCount() { + routes[destination] = append(otherRoute, rn) + } + } + } + + return routes +} + func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { snap, err := mesh.GetMesh() @@ -96,26 +160,19 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } peerToClients := make(map[string][]net.IPNet) - - routes := make([]lib.Route, 1) + routes := m.getRoutes(snap) + installedRoutes := make([]lib.Route, 0) for _, n := range nodes { if NodeEquals(n, self) { continue } - for _, route := range n.GetRoutes() { - - routes = append(routes, lib.Route{ - Gateway: n.GetWgHost().IP, - Destination: *route.GetDestination(), - }) - } - if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE { - peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int { + hashFunc := func(mn MeshNode) int { return lib.HashString(mn.GetWgHost().String()) - }) + } + peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc) clients, ok := peerToClients[peer.GetWgHost().String()] @@ -129,20 +186,31 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } dev, _ := mesh.GetDevice() - - peer, err := m.convertMeshNode(n, dev, peerToClients) + peer, err := m.convertMeshNode(n, dev, peerToClients, routes) if err != nil { return err } + for _, route := range peer.AllowedIPs { + ula := &ip.ULABuilder{} + ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) + + if !ipNet.Contains(route.IP) { + + installedRoutes = append(installedRoutes, lib.Route{ + Gateway: n.GetWgHost().IP, + Destination: route, + }) + } + } + peerConfigs[count] = *peer count++ } cfg := wgtypes.Config{ - Peers: peerConfigs, - ReplacePeers: true, + Peers: peerConfigs, } dev, err := mesh.GetDevice() @@ -151,7 +219,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return err } - err = m.routeInstaller.InstallRoutes(dev.Name, routes...) + err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) if err != nil { return err diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 7e4350d..f29b03e 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -64,6 +64,11 @@ func NodeEquals(node1, node2 MeshNode) bool { return key1.String() == key2.String() } +func RouteEquals(route1, route2 Route) bool { + return route1.GetDestination().String() == route2.GetDestination().String() && + route1.GetHopCount() == route2.GetHopCount() +} + func NodeID(node MeshNode) string { key, _ := node.GetPublicKey() return key.String()