34-fix-routing

- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
This commit is contained in:
Tim Beatham 2023-11-27 15:56:30 +00:00
parent aef8b59f22
commit a2517a1e72
3 changed files with 97 additions and 24 deletions

View File

@ -18,7 +18,7 @@ func HashString(value string) int {
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V {
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
@ -26,7 +26,7 @@ func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V {
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
keyFunc(v),
bucketFunc(v),
}
})

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -26,7 +27,19 @@ type WgMeshConfigApplyer struct {
routeInstaller route.RouteInstaller
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) {
type routeNode struct {
gateway string
route Route
}
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
@ -42,16 +55,36 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
for _, route := range node.GetRoutes() {
allowedips = append(allowedips, *route.GetDestination())
}
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() {
bestRoutes := routes[route.GetDestination().String()]
if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination())
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination())
}
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
@ -73,6 +106,37 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
return &peerConfig, nil
}
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(mesh MeshSnapshot) map[string][]routeNode {
routes := make(map[string][]routeNode)
for _, node := range mesh.GetNodes() {
for _, route := range node.GetRoutes() {
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
pubKey, _ := node.GetPublicKey()
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if otherRoute[0].route.GetHopCount() > route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
@ -96,26 +160,19 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
}
peerToClients := make(map[string][]net.IPNet)
routes := make([]lib.Route, 1)
routes := m.getRoutes(snap)
installedRoutes := make([]lib.Route, 0)
for _, n := range nodes {
if NodeEquals(n, self) {
continue
}
for _, route := range n.GetRoutes() {
routes = append(routes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: *route.GetDestination(),
})
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int {
hashFunc := func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
})
}
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
clients, ok := peerToClients[peer.GetWgHost().String()]
@ -129,20 +186,31 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
}
dev, _ := mesh.GetDevice()
peer, err := m.convertMeshNode(n, dev, peerToClients)
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil {
return err
}
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs[count] = *peer
count++
}
cfg := wgtypes.Config{
Peers: peerConfigs,
ReplacePeers: true,
}
dev, err := mesh.GetDevice()
@ -151,7 +219,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
err = m.routeInstaller.InstallRoutes(dev.Name, routes...)
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
if err != nil {
return err

View File

@ -64,6 +64,11 @@ func NodeEquals(node1, node2 MeshNode) bool {
return key1.String() == key2.String()
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()