From 0058c9f4c97dd9fa1ab5513b15c6f74afc9760d4 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Fri, 8 Dec 2023 11:49:24 +0000 Subject: [PATCH] 47-default-routing Implementing default routing so that all traffic goes out of an exit point. --- pkg/conf/conf.go | 2 ++ pkg/mesh/config.go | 17 +++++++++++------ pkg/mesh/manager.go | 2 +- pkg/mesh/route.go | 24 ++++++++++++++++++++---- pkg/mesh/types.go | 2 +- 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 8ae48c3..9063485 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -47,6 +47,8 @@ type WgMeshConfiguration struct { IPDiscovery IPDiscovery `yaml:"ipDiscovery"` // AdvertiseRoutes advertises other meshes if the node is in multiple meshes AdvertiseRoutes bool `yaml:"advertiseRoutes"` + // AdvertiseDefaultRoute advertises a default route out of the mesh. + AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"` // Endpoint is the IP in which this computer is publicly reachable. // usecase is when the node has multiple IP addresses Endpoint string `yaml:"publicEndpoint"` diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index f156259..00a073f 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -116,7 +116,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet { ula := &ip.ULABuilder{} ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) - return ipNet }) @@ -125,6 +124,12 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] for _, route := range node.GetRoutes() { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { + defaultRoute, _, _ := net.ParseCIDR("::/0") + + if prefix.IP.Equal(defaultRoute) && m.config.AdvertiseDefaultRoute { + return true + } + return prefix.Contains(route.GetDestination().IP) }) { continue @@ -168,6 +173,10 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) { self, err := m.meshManager.GetSelf(mesh.GetMeshId()) + routes := lib.Map(lib.MapKeys(m.getRoutes(mesh)), func(destination string) net.IPNet { + _, ipNet, _ := net.ParseCIDR(destination) + return *ipNet + }) if err != nil { return nil, err @@ -184,17 +193,13 @@ func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNod return nil, err } - allowedips := make([]net.IPNet, 1) - _, ipnet, _ := net.ParseCIDR("::/0") - allowedips[0] = *ipnet - peerCfgs := make([]wgtypes.PeerConfig, 1) peerCfgs[0] = wgtypes.PeerConfig{ PublicKey: pubKey, Endpoint: endpoint, PersistentKeepaliveInterval: &keepAlive, - AllowedIPs: allowedips, + AllowedIPs: routes, } cfg := wgtypes.Config{ diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 7576e87..60a142d 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -471,7 +471,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager { m.RouteManager = params.RouteManager if m.RouteManager == nil { - m.RouteManager = NewRouteManager(m) + m.RouteManager = NewRouteManager(m, ¶ms.Conf) } m.idGenerator = params.IdGenerator diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 8197f9d..1a43a6c 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -1,6 +1,9 @@ package mesh import ( + "net" + + "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" @@ -13,6 +16,7 @@ type RouteManager interface { type RouteManagerImpl struct { meshManager MeshManager + conf *conf.WgMeshConfiguration } func (r *RouteManagerImpl) UpdateRoutes() error { @@ -32,12 +36,22 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - routes, err := mesh1.GetRoutes(pubKey.String()) + routeMap, err := mesh1.GetRoutes(pubKey.String()) if err != nil { return err } + if r.conf.AdvertiseDefaultRoute { + _, defaultRoute, _ := net.ParseCIDR("::/0") + + mesh1.AddRoutes(NodeID(self), &RouteStub{ + Destination: defaultRoute, + HopCount: 0, + Path: make([]string, 0), + }) + } + for _, mesh2 := range meshes { if mesh1 == mesh2 { continue @@ -50,7 +64,9 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{ + routes := lib.MapValues(routeMap) + + err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{ Destination: ipNet, HopCount: 0, Path: make([]string, 0), @@ -88,6 +104,6 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { return nil } -func NewRouteManager(m MeshManager) RouteManager { - return &RouteManagerImpl{meshManager: m} +func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager { + return &RouteManagerImpl{meshManager: m, conf: conf} } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index d8e9eb2..72556d5 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -173,7 +173,7 @@ type MeshProviderFactory interface { // MeshNodeFactoryParams are the parameters required to construct // a mesh node type MeshNodeFactoryParams struct { -PublicKey *wgtypes.Key + PublicKey *wgtypes.Key NodeIP net.IP WgPort int Endpoint string