From 0058c9f4c97dd9fa1ab5513b15c6f74afc9760d4 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Fri, 8 Dec 2023 11:49:24 +0000 Subject: [PATCH 1/2] 47-default-routing Implementing default routing so that all traffic goes out of an exit point. --- pkg/conf/conf.go | 2 ++ pkg/mesh/config.go | 17 +++++++++++------ pkg/mesh/manager.go | 2 +- pkg/mesh/route.go | 24 ++++++++++++++++++++---- pkg/mesh/types.go | 2 +- 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 8ae48c3..9063485 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -47,6 +47,8 @@ type WgMeshConfiguration struct { IPDiscovery IPDiscovery `yaml:"ipDiscovery"` // AdvertiseRoutes advertises other meshes if the node is in multiple meshes AdvertiseRoutes bool `yaml:"advertiseRoutes"` + // AdvertiseDefaultRoute advertises a default route out of the mesh. + AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"` // Endpoint is the IP in which this computer is publicly reachable. // usecase is when the node has multiple IP addresses Endpoint string `yaml:"publicEndpoint"` diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index f156259..00a073f 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -116,7 +116,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet { ula := &ip.ULABuilder{} ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) - return ipNet }) @@ -125,6 +124,12 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] for _, route := range node.GetRoutes() { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { + defaultRoute, _, _ := net.ParseCIDR("::/0") + + if prefix.IP.Equal(defaultRoute) && m.config.AdvertiseDefaultRoute { + return true + } + return prefix.Contains(route.GetDestination().IP) }) { continue @@ -168,6 +173,10 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) { self, err := m.meshManager.GetSelf(mesh.GetMeshId()) + routes := lib.Map(lib.MapKeys(m.getRoutes(mesh)), func(destination string) net.IPNet { + _, ipNet, _ := net.ParseCIDR(destination) + return *ipNet + }) if err != nil { return nil, err @@ -184,17 +193,13 @@ func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNod return nil, err } - allowedips := make([]net.IPNet, 1) - _, ipnet, _ := net.ParseCIDR("::/0") - allowedips[0] = *ipnet - peerCfgs := make([]wgtypes.PeerConfig, 1) peerCfgs[0] = wgtypes.PeerConfig{ PublicKey: pubKey, Endpoint: endpoint, PersistentKeepaliveInterval: &keepAlive, - AllowedIPs: allowedips, + AllowedIPs: routes, } cfg := wgtypes.Config{ diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 7576e87..60a142d 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -471,7 +471,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager { m.RouteManager = params.RouteManager if m.RouteManager == nil { - m.RouteManager = NewRouteManager(m) + m.RouteManager = NewRouteManager(m, ¶ms.Conf) } m.idGenerator = params.IdGenerator diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 8197f9d..1a43a6c 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -1,6 +1,9 @@ package mesh import ( + "net" + + "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" 
"github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" @@ -13,6 +16,7 @@ type RouteManager interface { type RouteManagerImpl struct { meshManager MeshManager + conf *conf.WgMeshConfiguration } func (r *RouteManagerImpl) UpdateRoutes() error { @@ -32,12 +36,22 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - routes, err := mesh1.GetRoutes(pubKey.String()) + routeMap, err := mesh1.GetRoutes(pubKey.String()) if err != nil { return err } + if r.conf.AdvertiseDefaultRoute { + _, defaultRoute, _ := net.ParseCIDR("::/0") + + mesh1.AddRoutes(NodeID(self), &RouteStub{ + Destination: defaultRoute, + HopCount: 0, + Path: make([]string, 0), + }) + } + for _, mesh2 := range meshes { if mesh1 == mesh2 { continue @@ -50,7 +64,9 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{ + routes := lib.MapValues(routeMap) + + err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{ Destination: ipNet, HopCount: 0, Path: make([]string, 0), @@ -88,6 +104,6 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error { return nil } -func NewRouteManager(m MeshManager) RouteManager { - return &RouteManagerImpl{meshManager: m} +func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager { + return &RouteManagerImpl{meshManager: m, conf: conf} } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index d8e9eb2..72556d5 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -173,7 +173,7 @@ type MeshProviderFactory interface { // MeshNodeFactoryParams are the parameters required to construct // a mesh node type MeshNodeFactoryParams struct { -PublicKey *wgtypes.Key + PublicKey *wgtypes.Key NodeIP net.IP WgPort int Endpoint string From 815c4484ee81c75afac9f4167aceae77077eb4c2 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Fri, 8 Dec 2023 20:02:57 +0000 Subject: [PATCH 2/2] 47-default-routing Implemented default routing and improved size of gossip. Using 64 bit hash funciton to identify vector. --- pkg/crdt/datastore.go | 12 +++++-- pkg/crdt/factory.go | 4 +-- pkg/crdt/g_map.go | 38 +++++++++++--------- pkg/crdt/two_phase_map.go | 61 ++++++++++++++++++-------------- pkg/crdt/two_phase_map_syncer.go | 11 +++--- pkg/crdt/vector_clock.go | 40 +++++++++++---------- pkg/lib/rtnetlink.go | 10 +++++- pkg/mesh/config.go | 27 +++++++++++--- pkg/mesh/route.go | 13 +++---- pkg/route/route.go | 6 +++- pkg/sync/syncer.go | 2 +- pkg/sync/syncscheduler.go | 3 +- 12 files changed, 141 insertions(+), 86 deletions(-) diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go index 638af18..b3b94ca 100644 --- a/pkg/crdt/datastore.go +++ b/pkg/crdt/datastore.go @@ -179,8 +179,16 @@ func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) { // GetMesh() returns a snapshot of the mesh provided by the mesh provider. 
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) { + nodes := m.store.AsList() + + snapshot := make(map[string]MeshNode) + + for _, node := range nodes { + snapshot[node.PublicKey] = node + } + return &MeshSnapshot{ - Nodes: m.store.AsMap(), + Nodes: snapshot, }, nil } @@ -408,7 +416,7 @@ func (m *TwoPhaseStoreMeshManager) Prune() error { // GetPeers: get a list of contactable peers func (m *TwoPhaseStoreMeshManager) GetPeers() []string { - nodes := lib.MapValues(m.store.AsMap()) + nodes := m.store.AsList() nodes = lib.Filter(nodes, func(mn MeshNode) bool { if mn.Type != string(conf.PEER_ROLE) { return false diff --git a/pkg/crdt/factory.go b/pkg/crdt/factory.go index 571d430..4895bbf 100644 --- a/pkg/crdt/factory.go +++ b/pkg/crdt/factory.go @@ -18,9 +18,9 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) Client: params.Client, conf: params.Conf, store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { - h := fnv.New32a() + h := fnv.New64a() h.Write([]byte(s)) - return uint64(h.Sum32()) + return h.Sum64() }, uint64(3*params.Conf.KeepAliveTime)), }, nil } diff --git a/pkg/crdt/g_map.go b/pkg/crdt/g_map.go index 1d5bab4..791c8e9 100644 --- a/pkg/crdt/g_map.go +++ b/pkg/crdt/g_map.go @@ -15,7 +15,7 @@ type Bucket[D any] struct { // GMap is a set that can only grow in size type GMap[K cmp.Ordered, D any] struct { lock sync.RWMutex - contents map[K]Bucket[D] + contents map[uint64]Bucket[D] clock *VectorClock[K] } @@ -24,7 +24,7 @@ func (g *GMap[K, D]) Put(key K, value D) { clock := g.clock.IncrementClock() - g.contents[key] = Bucket[D]{ + g.contents[g.clock.hashFunc(key)] = Bucket[D]{ Vector: clock, Contents: value, } @@ -33,6 +33,10 @@ func (g *GMap[K, D]) Put(key K, value D) { } func (g *GMap[K, D]) Contains(key K) bool { + return g.contains(g.clock.hashFunc(key)) +} + +func (g *GMap[K, D]) contains(key uint64) bool { g.lock.RLock() _, ok := g.contents[key] @@ -42,7 +46,7 @@ func (g *GMap[K, D]) Contains(key K) bool { return ok } -func (g *GMap[K, D]) put(key K, b Bucket[D]) { +func (g *GMap[K, D]) put(key uint64, b Bucket[D]) { g.lock.Lock() if g.contents[key].Vector < b.Vector { @@ -52,7 +56,7 @@ func (g *GMap[K, D]) put(key K, b Bucket[D]) { g.lock.Unlock() } -func (g *GMap[K, D]) get(key K) Bucket[D] { +func (g *GMap[K, D]) get(key uint64) Bucket[D] { g.lock.RLock() bucket := g.contents[key] g.lock.RUnlock() @@ -61,14 +65,14 @@ func (g *GMap[K, D]) get(key K) Bucket[D] { } func (g *GMap[K, D]) Get(key K) D { - return g.get(key).Contents + return g.get(g.clock.hashFunc(key)).Contents } func (g *GMap[K, D]) Mark(key K) { g.lock.Lock() - bucket := g.contents[key] + bucket := g.contents[g.clock.hashFunc(key)] bucket.Gravestone = true - g.contents[key] = bucket + g.contents[g.clock.hashFunc(key)] = bucket g.lock.Unlock() } @@ -78,7 +82,7 @@ func (g *GMap[K, D]) IsMarked(key K) bool { g.lock.RLock() - bucket, ok := g.contents[key] + bucket, ok := g.contents[g.clock.hashFunc(key)] if ok { marked = bucket.Gravestone @@ -89,10 +93,10 @@ func (g *GMap[K, D]) IsMarked(key K) bool { return marked } -func (g *GMap[K, D]) Keys() []K { +func (g *GMap[K, D]) Keys() []uint64 { g.lock.RLock() - contents := make([]K, len(g.contents)) + contents := make([]uint64, len(g.contents)) index := 0 for key := range g.contents { @@ -104,8 +108,8 @@ func (g *GMap[K, D]) Keys() []K { return contents } -func (g *GMap[K, D]) Save() map[K]Bucket[D] { - buckets := make(map[K]Bucket[D]) +func (g *GMap[K, D]) Save() map[uint64]Bucket[D] { + 
buckets := make(map[uint64]Bucket[D]) g.lock.RLock() for key, value := range g.contents { @@ -116,8 +120,8 @@ func (g *GMap[K, D]) Save() map[K]Bucket[D] { return buckets } -func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] { - buckets := make(map[K]Bucket[D]) +func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] { + buckets := make(map[uint64]Bucket[D]) g.lock.RLock() for _, key := range keys { @@ -128,8 +132,8 @@ func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] { return buckets } -func (g *GMap[K, D]) GetClock() map[K]uint64 { - clock := make(map[K]uint64) +func (g *GMap[K, D]) GetClock() map[uint64]uint64 { + clock := make(map[uint64]uint64) g.lock.RLock() for key, bucket := range g.contents { @@ -166,7 +170,7 @@ func (g *GMap[K, D]) Prune() { func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] { return &GMap[K, D]{ - contents: make(map[K]Bucket[D]), + contents: make(map[uint64]Bucket[D]), clock: clock, } } diff --git a/pkg/crdt/two_phase_map.go b/pkg/crdt/two_phase_map.go index 452ba81..f735244 100644 --- a/pkg/crdt/two_phase_map.go +++ b/pkg/crdt/two_phase_map.go @@ -14,19 +14,24 @@ type TwoPhaseMap[K cmp.Ordered, D any] struct { } type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct { - Add map[K]Bucket[D] - Remove map[K]Bucket[bool] + Add map[uint64]Bucket[D] + Remove map[uint64]Bucket[bool] } // Contains checks whether the value exists in the map func (m *TwoPhaseMap[K, D]) Contains(key K) bool { - if !m.addMap.Contains(key) { + return m.contains(m.Clock.hashFunc(key)) +} + +// Contains checks whether the value exists in the map +func (m *TwoPhaseMap[K, D]) contains(key uint64) bool { + if !m.addMap.contains(key) { return false } addValue := m.addMap.get(key) - if !m.removeMap.Contains(key) { + if !m.removeMap.contains(key) { return true } @@ -45,6 +50,16 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D { return m.addMap.Get(key) } +func (m *TwoPhaseMap[K, D]) get(key uint64) D { + var result D + + if !m.contains(key) { + return result + } + + return m.addMap.get(key).Contents +} + // Put places the key K in the map func (m *TwoPhaseMap[K, D]) Put(key K, data D) { msgSequence := m.Clock.IncrementClock() @@ -61,13 +76,13 @@ func (m *TwoPhaseMap[K, D]) Remove(key K) { m.removeMap.Put(key, true) } -func (m *TwoPhaseMap[K, D]) Keys() []K { - keys := make([]K, 0) +func (m *TwoPhaseMap[K, D]) keys() []uint64 { + keys := make([]uint64, 0) addKeys := m.addMap.Keys() for _, key := range addKeys { - if !m.Contains(key) { + if !m.contains(key) { continue } @@ -77,16 +92,16 @@ func (m *TwoPhaseMap[K, D]) Keys() []K { return keys } -func (m *TwoPhaseMap[K, D]) AsMap() map[K]D { - theMap := make(map[K]D) +func (m *TwoPhaseMap[K, D]) AsList() []D { + theList := make([]D, 0) - keys := m.Keys() + keys := m.keys() for _, key := range keys { - theMap[key] = m.Get(key) + theList = append(theList, m.get(key)) } - return theMap + return theList } func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] { @@ -107,9 +122,9 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh } type TwoPhaseMapState[K cmp.Ordered] struct { - Vectors map[K]uint64 - AddContents map[K]uint64 - RemoveContents map[K]uint64 + Vectors map[uint64]uint64 + AddContents map[uint64]uint64 + RemoveContents map[uint64]uint64 } func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool { @@ -120,7 +135,7 @@ func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool { // Sums the current values of the vectors. 
Provides good approximation // of increasing numbers func (m *TwoPhaseMap[K, D]) GetHash() uint64 { - return m.addMap.GetHash() + m.removeMap.GetHash() + return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1) } // GetState: get the current vector clock of the add and remove @@ -136,16 +151,10 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] { } } -func (m *TwoPhaseMap[K, D]) UpdateVector(state *TwoPhaseMapState[K]) { - for key, value := range state.Vectors { - m.Clock.Put(key, value) - } -} - func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] { mapState := &TwoPhaseMapState[K]{ - AddContents: make(map[K]uint64), - RemoveContents: make(map[K]uint64), + AddContents: make(map[uint64]uint64), + RemoveContents: make(map[uint64]uint64), } for key, value := range state.AddContents { @@ -172,12 +181,12 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) { // Gravestone is local only to that node. // Discover ourselves if the node is alive m.addMap.put(key, value) - m.Clock.Put(key, value.Vector) + m.Clock.put(key, value.Vector) } for key, value := range snapshot.Remove { m.removeMap.put(key, value) - m.Clock.Put(key, value.Vector) + m.Clock.put(key, value.Vector) } } diff --git a/pkg/crdt/two_phase_map_syncer.go b/pkg/crdt/two_phase_map_syncer.go index 41e372f..1247823 100644 --- a/pkg/crdt/two_phase_map_syncer.go +++ b/pkg/crdt/two_phase_map_syncer.go @@ -44,7 +44,7 @@ func hash(syncer *TwoPhaseSyncer) ([]byte, bool) { err := enc.Encode(hash) if err != nil { - logging.Log.WriteInfof(err.Error()) + logging.Log.WriteErrorf(err.Error()) } syncer.IncrementState() @@ -59,7 +59,7 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) { err := dec.Decode(&hash) if err != nil { - logging.Log.WriteInfof(err.Error()) + logging.Log.WriteErrorf(err.Error()) } // If vector clocks are equal then no need to merge state @@ -74,7 +74,7 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) { err = enc.Encode(*syncer.mapState) if err != nil { - logging.Log.WriteInfof(err.Error()) + logging.Log.WriteErrorf(err.Error()) } syncer.IncrementState() @@ -93,10 +93,11 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) { err := dec.Decode(&mapState) if err != nil { - logging.Log.WriteInfof(err.Error()) + logging.Log.WriteErrorf(err.Error()) } difference := syncer.mapState.Difference(&mapState) + syncer.manager.store.Clock.Merge(mapState.Vectors) var sendBuffer bytes.Buffer enc := gob.NewEncoder(&sendBuffer) @@ -163,7 +164,7 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error { func (t *TwoPhaseSyncer) Complete() { logging.Log.WriteInfof("SYNC COMPLETED") - if t.state == FINISHED || t.state == MERGE { + if t.state >= MERGE { t.manager.store.Clock.IncrementClock() } } diff --git a/pkg/crdt/vector_clock.go b/pkg/crdt/vector_clock.go index 0439efa..78882c0 100644 --- a/pkg/crdt/vector_clock.go +++ b/pkg/crdt/vector_clock.go @@ -2,7 +2,6 @@ package crdt import ( "cmp" - "slices" "sync" "time" @@ -19,7 +18,7 @@ type VectorBucket struct { // Vector clock defines an abstract data type // for a vector clock implementation type VectorClock[K cmp.Ordered] struct { - vectors map[K]*VectorBucket + vectors map[uint64]*VectorBucket lock sync.RWMutex processID K staleTime uint64 @@ -40,7 +39,7 @@ func (m *VectorClock[K]) IncrementClock() uint64 { lastUpdate: uint64(time.Now().Unix()), } - m.vectors[m.processID] = &newBucket + m.vectors[m.hashFunc(m.processID)] = &newBucket m.lock.Unlock() return maxClock @@ -53,26 +52,28 @@ func (m 
*VectorClock[K]) GetHash() uint64 { hash := uint64(0) - sortedKeys := lib.MapKeys(m.vectors) - slices.Sort(sortedKeys) - for key, bucket := range m.vectors { - hash += m.hashFunc(key) - hash += bucket.clock + hash += key * (bucket.clock + 1) } m.lock.RUnlock() return hash } +func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) { + for key, value := range vectors { + m.put(key, value) + } +} + // getStale: get all entries that are stale within the mesh -func (m *VectorClock[K]) getStale() []K { +func (m *VectorClock[K]) getStale() []uint64 { m.lock.RLock() maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 { return max(i, vb.lastUpdate) }) - toRemove := make([]K, 0) + toRemove := make([]uint64, 0) for key, bucket := range m.vectors { if maxTimeStamp-bucket.lastUpdate > m.staleTime { @@ -97,10 +98,14 @@ func (m *VectorClock[K]) Prune() { } func (m *VectorClock[K]) GetTimestamp(processId K) uint64 { - return m.vectors[processId].lastUpdate + return m.vectors[m.hashFunc(m.processID)].lastUpdate } func (m *VectorClock[K]) Put(key K, value uint64) { + m.put(m.hashFunc(key), value) +} + +func (m *VectorClock[K]) put(key uint64, value uint64) { clockValue := uint64(0) m.lock.Lock() @@ -121,16 +126,13 @@ func (m *VectorClock[K]) Put(key K, value uint64) { m.lock.Unlock() } -func (m *VectorClock[K]) GetClock() map[K]uint64 { - clock := make(map[K]uint64) +func (m *VectorClock[K]) GetClock() map[uint64]uint64 { + clock := make(map[uint64]uint64) m.lock.RLock() - keys := lib.MapKeys(m.vectors) - slices.Sort(keys) - - for key, value := range clock { - clock[key] = value + for key, value := range m.vectors { + clock[key] = value.clock } m.lock.RUnlock() @@ -139,7 +141,7 @@ func (m *VectorClock[K]) GetClock() map[K]uint64 { func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] { return &VectorClock[K]{ - vectors: make(map[K]*VectorBucket), + vectors: make(map[uint64]*VectorBucket), processID: processID, staleTime: staleTime, hashFunc: hashFunc, diff --git a/pkg/lib/rtnetlink.go b/pkg/lib/rtnetlink.go index d26905a..f95b54b 100644 --- a/pkg/lib/rtnetlink.go +++ b/pkg/lib/rtnetlink.go @@ -248,6 +248,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R if route.equal(r) { return false } + + if family == unix.AF_INET && route.Destination.IP.To4() == nil { + return false + } + + if family == unix.AF_INET6 && route.Destination.IP.To16() == nil { + return false + } } return true } @@ -255,7 +263,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R toDelete := Filter(ifRoutes, shouldExclude) for _, route := range toDelete { - logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String()) + logging.Log.WriteInfof("Deleting route: %s", route.Destination.String()) err := c.DeleteRoute(ifName, route) if err != nil { diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 00a073f..ddd01ac 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -124,9 +124,10 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] for _, route := range node.GetRoutes() { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { - defaultRoute, _, _ := net.ParseCIDR("::/0") + v6Default, _, _ := net.ParseCIDR("::/0") + v4Default, _, _ := net.ParseCIDR("0.0.0.0/0") - if prefix.IP.Equal(defaultRoute) && m.config.AdvertiseDefaultRoute { + if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && 
m.config.AdvertiseDefaultRoute { return true } @@ -171,12 +172,16 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh return peer } -func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) { +func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) { self, err := m.meshManager.GetSelf(mesh.GetMeshId()) + ula := &ip.ULABuilder{} + meshNet, _ := ula.GetIPNet(mesh.GetMeshId()) + routes := lib.Map(lib.MapKeys(m.getRoutes(mesh)), func(destination string) net.IPNet { _, ipNet, _ := net.ParseCIDR(destination) return *ipNet }) + routes = append(routes, *meshNet) if err != nil { return nil, err @@ -202,10 +207,20 @@ func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNod AllowedIPs: routes, } + installedRoutes := make([]lib.Route, 0) + + for _, route := range peerCfgs[0].AllowedIPs { + installedRoutes = append(installedRoutes, lib.Route{ + Gateway: peer.GetWgHost().IP, + Destination: route, + }) + } + cfg := wgtypes.Config{ Peers: peerCfgs, } + m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...) return &cfg, err } @@ -260,7 +275,9 @@ func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, ula := &ip.ULABuilder{} ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) - if !ipNet.Contains(route.IP) { + _, defaultRoute, _ := net.ParseCIDR("::/0") + + if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) { installedRoutes = append(installedRoutes, lib.Route{ Gateway: n.GetWgHost().IP, Destination: route, @@ -314,7 +331,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { case conf.PEER_ROLE: cfg, err = m.getPeerConfig(mesh, peers, clients, dev) case conf.CLIENT_ROLE: - cfg, err = m.getClientConfig(mesh, peers, clients) + cfg, err = m.getClientConfig(mesh, peers, clients, dev) } if err != nil { diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 1a43a6c..2354367 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -43,13 +43,14 @@ func (r *RouteManagerImpl) UpdateRoutes() error { } if r.conf.AdvertiseDefaultRoute { - _, defaultRoute, _ := net.ParseCIDR("::/0") + _, ipv6Default, _ := net.ParseCIDR("::/0") - mesh1.AddRoutes(NodeID(self), &RouteStub{ - Destination: defaultRoute, - HopCount: 0, - Path: make([]string, 0), - }) + mesh1.AddRoutes(NodeID(self), + &RouteStub{ + Destination: ipv6Default, + HopCount: 0, + Path: make([]string, 0), + }) } for _, mesh2 := range meshes { diff --git a/pkg/route/route.go b/pkg/route/route.go index 11de7d7..976b6c4 100644 --- a/pkg/route/route.go +++ b/pkg/route/route.go @@ -19,7 +19,11 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) return err } - err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...) + ip6Routes := lib.Filter(routes, func(r lib.Route) bool { + return r.Destination.IP.To4() == nil + }) + + err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...) 
if err != nil { return err diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 075181d..95efe39 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -123,7 +123,7 @@ func (s *SyncerImpl) SyncMeshes() error { err := s.Sync(meshId) if err != nil { - return err + logging.Log.WriteErrorf(err.Error()) } } diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 61bdf41..4be4e30 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -8,7 +8,8 @@ import ( // Run implements SyncScheduler. func syncFunction(syncer Syncer) lib.TimerFunc { return func() error { - return syncer.SyncMeshes() + syncer.SyncMeshes() + return nil } }
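
For reference, opting a node in to the exit-point behaviour from PATCH 1/2 comes down to two flags on WgMeshConfiguration. A minimal sketch, assuming the struct is populated directly in Go rather than unmarshalled from the usual YAML file; the field names are taken from pkg/conf/conf.go in this series, while the endpoint value is only a placeholder:

package main

import (
	"fmt"

	"github.com/tim-beatham/wgmesh/pkg/conf"
)

func main() {
	// AdvertiseRoutes and AdvertiseDefaultRoute back the yaml tags
	// "advertiseRoutes" and "advertiseDefaults" in pkg/conf/conf.go.
	// With AdvertiseDefaultRoute set, RouteManagerImpl.UpdateRoutes
	// advertises ::/0 for this node, so the rest of the mesh sends
	// all traffic out through it.
	cfg := conf.WgMeshConfiguration{
		AdvertiseRoutes:       true,
		AdvertiseDefaultRoute: true,
		Endpoint:              "203.0.113.10", // placeholder public endpoint
	}

	fmt.Printf("%+v\n", cfg)
}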
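
On the gossip-size improvement in PATCH 2/2: the two-phase map, its grow-only maps and the vector clock are now keyed by a 64-bit FNV-1a digest of the node key rather than the key itself, so synchronised state carries fixed-size uint64 identifiers instead of full node ID strings. A small self-contained sketch of the hash as wired into NewTwoPhaseMap in pkg/crdt/factory.go; the node IDs below are made-up placeholders:

package main

import (
	"fmt"
	"hash/fnv"
)

// hashKey matches the function passed to NewTwoPhaseMap in
// pkg/crdt/factory.go after this patch: a 64-bit FNV-1a digest of the
// node's key. GMap buckets, TwoPhaseMapState vectors and the vector
// clock are all indexed by this uint64.
func hashKey(s string) uint64 {
	h := fnv.New64a()
	h.Write([]byte(s))
	return h.Sum64()
}

func main() {
	for _, id := range []string{"node-a-id", "node-b-id"} {
		fmt.Printf("%-10s -> %016x\n", id, hashKey(id))
	}
}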
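
The syncer's prepare step compares a state hash from both sides and skips the exchange when they match, so the hash must not depend on map iteration order. After this patch the vector clock hashes as a sum of key*(clock+1) terms (hence the sorted-key pass was dropped), and TwoPhaseMap.GetHash combines the add- and remove-side hashes as (add+1)*(remove+1). A sketch of the per-entry rule, assuming a plain map in place of the lock-guarded clock:

package main

import "fmt"

// stateHash mirrors the per-entry rule in pkg/crdt/vector_clock.go
// after this patch: each (key, clock) pair contributes key*(clock+1),
// so the sum is independent of iteration order and identical on
// replicas holding the same entries.
func stateHash(vectors map[uint64]uint64) uint64 {
	var h uint64
	for key, clock := range vectors {
		h += key * (clock + 1)
	}
	return h
}

func main() {
	add := map[uint64]uint64{0xdead: 3, 0xbeef: 7}
	remove := map[uint64]uint64{0xbeef: 1}

	// Iteration order does not change the result.
	fmt.Println(stateHash(add))

	// TwoPhaseMap.GetHash combines the add- and remove-side hashes
	// as (add+1)*(remove+1).
	fmt.Println((stateHash(add) + 1) * (stateHash(remove) + 1))
}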