diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index cca0a35..90a3701 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -37,7 +37,7 @@ type WgConfiguration struct { // service for IPDiscoverability IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"` // AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes - AdvertiseRoutes *bool `yaml:"advertiseRoutes"` + AdvertiseRoutes *bool `yaml:"advertiseRoute"` // AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route // for all nodes to route their packets to AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults"` diff --git a/pkg/lib/rtnetlink.go b/pkg/lib/rtnetlink.go index 3daa5ef..d6d4c8d 100644 --- a/pkg/lib/rtnetlink.go +++ b/pkg/lib/rtnetlink.go @@ -225,8 +225,11 @@ type Route struct { } func (r1 Route) equal(r2 Route) bool { + mask1Ones, _ := r1.Destination.Mask.Size() + mask2Ones, _ := r2.Destination.Mask.Size() + return r1.Gateway.String() == r2.Gateway.String() && - r1.Destination.String() == r2.Destination.String() + (mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP)) } // DeleteRoutes deletes all routes not in exclude @@ -257,18 +260,11 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R shouldExclude := func(r Route) bool { for _, route := range exclude { - if route.equal(r) { - return false - } - - if family == unix.AF_INET && route.Destination.IP.To4() == nil { - return false - } - - if family == unix.AF_INET6 && route.Destination.IP.To16() == nil { + if r.equal(route) { return false } } + return true } diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 6a24b25..eb4db01 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -134,10 +134,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] for _, route := range node.GetRoutes() { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { - v6Default, _, _ := net.ParseCIDR("::/0") - v4Default, _, _ := net.ParseCIDR("0.0.0.0/0") - - if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute { + if prefix.IP.Equal(net.IPv6zero) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute { return true } diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 26d8a9d..9809e94 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -510,7 +510,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager { m.RouteManager = params.RouteManager if m.RouteManager == nil { - m.RouteManager = NewRouteManager(m, ¶ms.Conf) + m.RouteManager = NewRouteManager(m) } if params.CommandRunner == nil { diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 39ef2bc..9d39fb4 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -3,7 +3,6 @@ package mesh import ( "net" - "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" ) @@ -14,7 +13,6 @@ type RouteManager interface { type RouteManagerImpl struct { meshManager MeshManager - conf *conf.DaemonConfiguration } func (r *RouteManagerImpl) UpdateRoutes() error { @@ -32,23 +30,25 @@ func (r *RouteManagerImpl) UpdateRoutes() error { routes[mesh1.GetMeshId()] = make([]Route, 0) } + if *mesh1.GetConfiguration().AdvertiseDefaultRoute { + _, ipv6Default, _ := net.ParseCIDR("::/0") + + defaultRoute := &RouteStub{ + Destination: ipv6Default, + HopCount: 0, + Path: []string{mesh1.GetMeshId()}, + } + + mesh1.AddRoutes(NodeID(self), defaultRoute) + routes[mesh1.GetMeshId()] = append(routes[mesh1.GetMeshId()], defaultRoute) + } + routeMap, err := mesh1.GetRoutes(NodeID(self)) if err != nil { return err } - if *r.conf.BaseConfiguration.AdvertiseDefaultRoute { - _, ipv6Default, _ := net.ParseCIDR("::/0") - - mesh1.AddRoutes(NodeID(self), - &RouteStub{ - Destination: ipv6Default, - HopCount: 0, - Path: make([]string, 0), - }) - } - for _, mesh2 := range meshes { routeValues, ok := routes[mesh2.GetMeshId()] @@ -75,8 +75,9 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return s == mesh2.GetMeshId() } - // Ensure that the route does not see it's own IP - return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh) + // Remove any potential routing loops + return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && + !lib.Contains(r.GetPath()[1:], pathNotMesh) }) routes[mesh2.GetMeshId()] = routeValues @@ -106,6 +107,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return nil } -func NewRouteManager(m MeshManager, conf *conf.DaemonConfiguration) RouteManager { - return &RouteManagerImpl{meshManager: m, conf: conf} +func NewRouteManager(m MeshManager) RouteManager { + return &RouteManagerImpl{meshManager: m} }