From d8e156f13fd0cda89131460beeb804587a400cd6 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Mon, 27 Nov 2023 18:55:41 +0000 Subject: [PATCH] 36-add-route-path-into-route-object Added the route path into the route object so that we can see what meshes packets are routed across. --- cmd/wg-mesh/main.go | 10 +++- pkg/api/apiserver.go | 11 ++-- pkg/api/types.go | 4 +- pkg/automerge/automerge.go | 18 ++++-- pkg/automerge/types.go | 4 +- pkg/ctrlserver/ctrltypes.go | 7 ++- pkg/lib/conv.go | 10 ++++ pkg/mesh/config.go | 31 ++++++++-- pkg/mesh/route.go | 109 +----------------------------------- pkg/mesh/types.go | 7 +++ pkg/query/query.go | 3 + pkg/robin/requester.go | 7 ++- 12 files changed, 88 insertions(+), 133 deletions(-) diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 498d5c0..7462ceb 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -8,7 +8,9 @@ import ( "time" "github.com/akamensky/argparse" + "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" + "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" ) @@ -93,9 +95,13 @@ func getMesh(client *ipcRpc.Client, meshId string) { fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("Wg IP: " + node.WgHost) - fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())) + fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()) - advertiseRoutes := strings.Join(node.Routes, ",") + mapFunc := func(r ctrlserver.MeshRoute) string { + return r.Destination + } + + advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",") fmt.Printf("Routes: %s\n", advertiseRoutes) fmt.Println("---") diff --git a/pkg/api/apiserver.go b/pkg/api/apiserver.go index cd132a3..dd3cf93 100644 --- a/pkg/api/apiserver.go +++ b/pkg/api/apiserver.go @@ -30,15 +30,14 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route { routes := make([]Route, len(meshNode.Routes)) for index, route := range meshNode.Routes { - word, err := s.words.Convert(route) - if err != nil { - fmt.Println(err.Error()) + if route.Path == nil { + route.Path = make([]string, 0) } routes[index] = Route{ - Prefix: route, - RouteId: word, + Prefix: route.Destination, + Path: route.Path, } } @@ -47,7 +46,7 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route { func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode { if meshNode.Routes == nil { - meshNode.Routes = make([]string, 0) + meshNode.Routes = make([]ctrlserver.MeshRoute, 0) } alias := meshNode.Alias diff --git a/pkg/api/types.go b/pkg/api/types.go index 6439814..0141dcb 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -1,8 +1,8 @@ package api type Route struct { - RouteId string `json:"routeId"` - Prefix string `json:"prefix"` + Prefix string `json:"prefix"` + Path []string `json:"path"` } type SmegNode struct { diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 7baced2..e57a7ed 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -340,7 +340,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { for _, route := range routes { err = routeMap.Map().Set(route.GetDestination().String(), Route{ Destination: route.GetDestination().String(), - HopCount: int64(route.GetHopCount()), + Path: route.GetPath(), }) if err != nil { @@ -372,6 +372,7 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { } routes, err := automerge.As[map[string]Route](routeMap) + return lib.MapValues(routes), err } @@ -398,10 +399,10 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e for _, route := range nodeRoutes { otherRoute, ok := routes[route.GetDestination().String()] - if !ok || route.GetHopCount() < otherRoute.GetHopCount() { + if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() { routes[route.GetDestination().String()] = &Route{ Destination: route.GetDestination().String(), - HopCount: int64(route.GetHopCount()) + 1, + Path: append(route.Path, m.GetMeshId()), } } } @@ -524,7 +525,10 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 { func (m *MeshNodeCrdt) GetRoutes() []mesh.Route { return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route { - return &r + return &Route{ + Destination: r.Destination, + Path: r.Path, + } }) } @@ -588,5 +592,9 @@ func (r *Route) GetDestination() *net.IPNet { } func (r *Route) GetHopCount() int { - return int(r.HopCount) + return len(r.Path) +} + +func (r *Route) GetPath() []string { + return r.Path } diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index ae35d25..150ff8d 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -2,8 +2,8 @@ package crdt // Route: Represents a CRDT of the given route type Route struct { - Destination string `automerge:"destination"` - HopCount int64 `automerge:"hopCount"` + Destination string `automerge:"destination"` + Path []string `automerge:"path"` } // MeshNodeCrdt: Represents a CRDT for a mesh nodes diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index 4215ad4..0d03ca6 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -9,6 +9,11 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +type MeshRoute struct { + Destination string + Path []string +} + // Represents a WireGuard MeshNode type MeshNode struct { HostEndpoint string @@ -16,7 +21,7 @@ type MeshNode struct { PublicKey string WgHost string Timestamp int64 - Routes []string + Routes []MeshRoute Description string Alias string Services map[string]string diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go index dc73545..28e1b33 100644 --- a/pkg/lib/conv.go +++ b/pkg/lib/conv.go @@ -66,3 +66,13 @@ func Filter[V any](list []V, f filterFunc[V]) []V { return newList } + +func Contains[V any](list []V, proposition func(V) bool) bool { + for _, elem := range list { + if proposition(elem) { + return true + } + } + + return false +} diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index d1c4327..8cda57a 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -108,14 +108,35 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev // getRoutes: finds the routes with the least hop distance. If more than one route exists // consistently hash to evenly spread the distribution of traffic -func (m *WgMeshConfigApplyer) getRoutes(mesh MeshSnapshot) map[string][]routeNode { +func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode { + mesh, _ := meshProvider.GetMesh() + routes := make(map[string][]routeNode) + meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet { + ula := &ip.ULABuilder{} + ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) + + return ipNet + }) + for _, node := range mesh.GetNodes() { - for _, route := range node.GetRoutes() { + pubKey, _ := node.GetPublicKey() + meshRoutes, _ := meshProvider.GetRoutes(pubKey.String()) + + for _, route := range meshRoutes { + if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { + if prefix == nil || route == nil || route.GetDestination() == nil { + return false + } + + return prefix.Contains(route.GetDestination().IP) + }) { + continue + } + destination := route.GetDestination().String() otherRoute, ok := routes[destination] - pubKey, _ := node.GetPublicKey() rn := routeNode{ gateway: pubKey.String(), @@ -126,7 +147,7 @@ func (m *WgMeshConfigApplyer) getRoutes(mesh MeshSnapshot) map[string][]routeNod otherRoute = make([]routeNode, 1) otherRoute[0] = rn routes[destination] = otherRoute - } else if otherRoute[0].route.GetHopCount() > route.GetHopCount() { + } else if route.GetHopCount() < otherRoute[0].route.GetHopCount() { otherRoute[0] = rn } else if otherRoute[0].route.GetHopCount() == route.GetHopCount() { routes[destination] = append(otherRoute, rn) @@ -160,7 +181,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } peerToClients := make(map[string][]net.IPNet) - routes := m.getRoutes(snap) + routes := m.getRoutes(mesh) installedRoutes := make([]lib.Route, 0) for _, n := range nodes { diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 3c54613..8197f9d 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -1,13 +1,9 @@ package mesh import ( - "fmt" - "net" - "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" - "golang.org/x/sys/unix" ) type RouteManager interface { @@ -57,6 +53,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error { err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{ Destination: ipNet, HopCount: 0, + Path: make([]string, 0), })...) if err != nil { @@ -91,110 +88,6 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { return nil } -// AddRoute adds a route to the given interface -func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error { - rtnl, err := lib.NewRtNetlinkConfig() - - if err != nil { - return fmt.Errorf("failed to create config: %w", err) - } - defer rtnl.Close() - - // Delete any routes that may be vacant - err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...) - - if err != nil { - return err - } - - for _, route := range routes { - if route.Destination.String() == meshPrefix { - continue - } - - err = rtnl.AddRoute(ifName, route) - - if err != nil { - return err - } - } - - return nil -} - -func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error { - routeMapFunc := func(route string) lib.Route { - _, cidr, _ := net.ParseCIDR(route) - - r := lib.Route{ - Destination: *cidr, - Gateway: node.GetWgHost().IP, - } - return r - } - - ipBuilder := &ip.ULABuilder{} - ipNet, err := ipBuilder.GetIPNet(meshid) - - if err != nil { - return err - } - - theRoutes := lib.Map(node.GetRoutes(), func(r Route) string { - return r.GetDestination().String() - }) - - routes := lib.Map(append(theRoutes, ipNet.String()), routeMapFunc) - return m.addRoute(ifName, ipNet.String(), routes...) -} - -func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error { - mesh, err := meshProvider.GetMesh() - - if err != nil { - return err - } - - dev, err := meshProvider.GetDevice() - - if err != nil { - return err - } - - self, err := m.meshManager.GetSelf(meshProvider.GetMeshId()) - - if err != nil { - return err - } - - for _, node := range mesh.GetNodes() { - if NodeEquals(self, node) { - continue - } - - err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node) - - if err != nil { - return err - } - } - - return nil -} - -// InstallRoutes installs all routes to the RIB -func (r *RouteManagerImpl) InstallRoutes() error { - for _, mesh := range r.meshManager.GetMeshes() { - err := r.installRoutes(mesh) - - if err != nil { - return err - } - } - - return nil -} - func NewRouteManager(m MeshManager) RouteManager { return &RouteManagerImpl{meshManager: m} } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index f29b03e..8ab2acf 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -15,11 +15,14 @@ type Route interface { GetDestination() *net.IPNet // GetHopCount: get the total hopcount of the prefix GetHopCount() int + // GetPath: get a list of AS paths to get to the destination + GetPath() []string } type RouteStub struct { Destination *net.IPNet HopCount int + Path []string } func (r *RouteStub) GetDestination() *net.IPNet { @@ -30,6 +33,10 @@ func (r *RouteStub) GetHopCount() int { return r.HopCount } +func (r *RouteStub) GetPath() []string { + return r.Path +} + // MeshNode represents an implementation of a node in a mesh type MeshNode interface { // GetHostEndpoint: gets the gRPC endpoint of the node diff --git a/pkg/query/query.go b/pkg/query/query.go index 4884a7d..95fdfbe 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -3,6 +3,7 @@ package query import ( "encoding/json" "fmt" + "strings" "github.com/jmespath/go-jmespath" "github.com/tim-beatham/wgmesh/pkg/conf" @@ -27,6 +28,7 @@ type QueryError struct { type QueryRoute struct { Destination string `json:"destination"` HopCount int `json:"hopCount"` + Path string `json:"path"` } type QueryNode struct { @@ -87,6 +89,7 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode { return QueryRoute{ Destination: r.GetDestination().String(), HopCount: r.GetHopCount(), + Path: strings.Join(r.GetPath(), ","), } }) queryNode.Description = node.GetDescription() diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 51aa823..f6ccdf4 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -152,8 +152,11 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { PublicKey: pubKey.String(), WgHost: node.GetWgHost().String(), Timestamp: node.GetTimeStamp(), - Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) string { - return r.GetDestination().String() + Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute { + return ctrlserver.MeshRoute{ + Destination: r.GetDestination().String(), + Path: r.GetPath(), + } }), Description: node.GetDescription(), Alias: node.GetAlias(),