forked from extern/smegmesh
Merge pull request #48 from tim-beatham/47-default-routing
47 default routing
This commit is contained in:
commit
52feb5767b
@ -47,6 +47,8 @@ type WgMeshConfiguration struct {
|
||||
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
|
||||
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
|
||||
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
|
||||
// AdvertiseDefaultRoute advertises a default route out of the mesh.
|
||||
AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"`
|
||||
// Endpoint is the IP in which this computer is publicly reachable.
|
||||
// usecase is when the node has multiple IP addresses
|
||||
Endpoint string `yaml:"publicEndpoint"`
|
||||
|
@ -179,8 +179,16 @@ func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
|
||||
|
||||
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
|
||||
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
||||
nodes := m.store.AsList()
|
||||
|
||||
snapshot := make(map[string]MeshNode)
|
||||
|
||||
for _, node := range nodes {
|
||||
snapshot[node.PublicKey] = node
|
||||
}
|
||||
|
||||
return &MeshSnapshot{
|
||||
Nodes: m.store.AsMap(),
|
||||
Nodes: snapshot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -408,7 +416,7 @@ func (m *TwoPhaseStoreMeshManager) Prune() error {
|
||||
|
||||
// GetPeers: get a list of contactable peers
|
||||
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
||||
nodes := lib.MapValues(m.store.AsMap())
|
||||
nodes := m.store.AsList()
|
||||
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
||||
if mn.Type != string(conf.PEER_ROLE) {
|
||||
return false
|
||||
|
@ -18,9 +18,9 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
|
||||
Client: params.Client,
|
||||
conf: params.Conf,
|
||||
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
|
||||
h := fnv.New32a()
|
||||
h := fnv.New64a()
|
||||
h.Write([]byte(s))
|
||||
return uint64(h.Sum32())
|
||||
return h.Sum64()
|
||||
}, uint64(3*params.Conf.KeepAliveTime)),
|
||||
}, nil
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ type Bucket[D any] struct {
|
||||
// GMap is a set that can only grow in size
|
||||
type GMap[K cmp.Ordered, D any] struct {
|
||||
lock sync.RWMutex
|
||||
contents map[K]Bucket[D]
|
||||
contents map[uint64]Bucket[D]
|
||||
clock *VectorClock[K]
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ func (g *GMap[K, D]) Put(key K, value D) {
|
||||
|
||||
clock := g.clock.IncrementClock()
|
||||
|
||||
g.contents[key] = Bucket[D]{
|
||||
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
|
||||
Vector: clock,
|
||||
Contents: value,
|
||||
}
|
||||
@ -33,6 +33,10 @@ func (g *GMap[K, D]) Put(key K, value D) {
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Contains(key K) bool {
|
||||
return g.contains(g.clock.hashFunc(key))
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) contains(key uint64) bool {
|
||||
g.lock.RLock()
|
||||
|
||||
_, ok := g.contents[key]
|
||||
@ -42,7 +46,7 @@ func (g *GMap[K, D]) Contains(key K) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
||||
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
|
||||
g.lock.Lock()
|
||||
|
||||
if g.contents[key].Vector < b.Vector {
|
||||
@ -52,7 +56,7 @@ func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
||||
g.lock.Unlock()
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) get(key K) Bucket[D] {
|
||||
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
|
||||
g.lock.RLock()
|
||||
bucket := g.contents[key]
|
||||
g.lock.RUnlock()
|
||||
@ -61,14 +65,14 @@ func (g *GMap[K, D]) get(key K) Bucket[D] {
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Get(key K) D {
|
||||
return g.get(key).Contents
|
||||
return g.get(g.clock.hashFunc(key)).Contents
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Mark(key K) {
|
||||
g.lock.Lock()
|
||||
bucket := g.contents[key]
|
||||
bucket := g.contents[g.clock.hashFunc(key)]
|
||||
bucket.Gravestone = true
|
||||
g.contents[key] = bucket
|
||||
g.contents[g.clock.hashFunc(key)] = bucket
|
||||
g.lock.Unlock()
|
||||
}
|
||||
|
||||
@ -78,7 +82,7 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
|
||||
|
||||
g.lock.RLock()
|
||||
|
||||
bucket, ok := g.contents[key]
|
||||
bucket, ok := g.contents[g.clock.hashFunc(key)]
|
||||
|
||||
if ok {
|
||||
marked = bucket.Gravestone
|
||||
@ -89,10 +93,10 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
|
||||
return marked
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Keys() []K {
|
||||
func (g *GMap[K, D]) Keys() []uint64 {
|
||||
g.lock.RLock()
|
||||
|
||||
contents := make([]K, len(g.contents))
|
||||
contents := make([]uint64, len(g.contents))
|
||||
index := 0
|
||||
|
||||
for key := range g.contents {
|
||||
@ -104,8 +108,8 @@ func (g *GMap[K, D]) Keys() []K {
|
||||
return contents
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
||||
buckets := make(map[K]Bucket[D])
|
||||
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
|
||||
buckets := make(map[uint64]Bucket[D])
|
||||
g.lock.RLock()
|
||||
|
||||
for key, value := range g.contents {
|
||||
@ -116,8 +120,8 @@ func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
||||
return buckets
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
||||
buckets := make(map[K]Bucket[D])
|
||||
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
|
||||
buckets := make(map[uint64]Bucket[D])
|
||||
g.lock.RLock()
|
||||
|
||||
for _, key := range keys {
|
||||
@ -128,8 +132,8 @@ func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
||||
return buckets
|
||||
}
|
||||
|
||||
func (g *GMap[K, D]) GetClock() map[K]uint64 {
|
||||
clock := make(map[K]uint64)
|
||||
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
|
||||
clock := make(map[uint64]uint64)
|
||||
g.lock.RLock()
|
||||
|
||||
for key, bucket := range g.contents {
|
||||
@ -166,7 +170,7 @@ func (g *GMap[K, D]) Prune() {
|
||||
|
||||
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
|
||||
return &GMap[K, D]{
|
||||
contents: make(map[K]Bucket[D]),
|
||||
contents: make(map[uint64]Bucket[D]),
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
@ -14,19 +14,24 @@ type TwoPhaseMap[K cmp.Ordered, D any] struct {
|
||||
}
|
||||
|
||||
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
|
||||
Add map[K]Bucket[D]
|
||||
Remove map[K]Bucket[bool]
|
||||
Add map[uint64]Bucket[D]
|
||||
Remove map[uint64]Bucket[bool]
|
||||
}
|
||||
|
||||
// Contains checks whether the value exists in the map
|
||||
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
|
||||
if !m.addMap.Contains(key) {
|
||||
return m.contains(m.Clock.hashFunc(key))
|
||||
}
|
||||
|
||||
// Contains checks whether the value exists in the map
|
||||
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
|
||||
if !m.addMap.contains(key) {
|
||||
return false
|
||||
}
|
||||
|
||||
addValue := m.addMap.get(key)
|
||||
|
||||
if !m.removeMap.Contains(key) {
|
||||
if !m.removeMap.contains(key) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -45,6 +50,16 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
|
||||
return m.addMap.Get(key)
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
|
||||
var result D
|
||||
|
||||
if !m.contains(key) {
|
||||
return result
|
||||
}
|
||||
|
||||
return m.addMap.get(key).Contents
|
||||
}
|
||||
|
||||
// Put places the key K in the map
|
||||
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
||||
msgSequence := m.Clock.IncrementClock()
|
||||
@ -61,13 +76,13 @@ func (m *TwoPhaseMap[K, D]) Remove(key K) {
|
||||
m.removeMap.Put(key, true)
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) Keys() []K {
|
||||
keys := make([]K, 0)
|
||||
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
|
||||
keys := make([]uint64, 0)
|
||||
|
||||
addKeys := m.addMap.Keys()
|
||||
|
||||
for _, key := range addKeys {
|
||||
if !m.Contains(key) {
|
||||
if !m.contains(key) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -77,16 +92,16 @@ func (m *TwoPhaseMap[K, D]) Keys() []K {
|
||||
return keys
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) AsMap() map[K]D {
|
||||
theMap := make(map[K]D)
|
||||
func (m *TwoPhaseMap[K, D]) AsList() []D {
|
||||
theList := make([]D, 0)
|
||||
|
||||
keys := m.Keys()
|
||||
keys := m.keys()
|
||||
|
||||
for _, key := range keys {
|
||||
theMap[key] = m.Get(key)
|
||||
theList = append(theList, m.get(key))
|
||||
}
|
||||
|
||||
return theMap
|
||||
return theList
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
|
||||
@ -107,9 +122,9 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
|
||||
}
|
||||
|
||||
type TwoPhaseMapState[K cmp.Ordered] struct {
|
||||
Vectors map[K]uint64
|
||||
AddContents map[K]uint64
|
||||
RemoveContents map[K]uint64
|
||||
Vectors map[uint64]uint64
|
||||
AddContents map[uint64]uint64
|
||||
RemoveContents map[uint64]uint64
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
||||
@ -120,7 +135,7 @@ func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
||||
// Sums the current values of the vectors. Provides good approximation
|
||||
// of increasing numbers
|
||||
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
||||
return m.addMap.GetHash() + m.removeMap.GetHash()
|
||||
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
|
||||
}
|
||||
|
||||
// GetState: get the current vector clock of the add and remove
|
||||
@ -136,16 +151,10 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMap[K, D]) UpdateVector(state *TwoPhaseMapState[K]) {
|
||||
for key, value := range state.Vectors {
|
||||
m.Clock.Put(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
||||
mapState := &TwoPhaseMapState[K]{
|
||||
AddContents: make(map[K]uint64),
|
||||
RemoveContents: make(map[K]uint64),
|
||||
AddContents: make(map[uint64]uint64),
|
||||
RemoveContents: make(map[uint64]uint64),
|
||||
}
|
||||
|
||||
for key, value := range state.AddContents {
|
||||
@ -172,12 +181,12 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
||||
// Gravestone is local only to that node.
|
||||
// Discover ourselves if the node is alive
|
||||
m.addMap.put(key, value)
|
||||
m.Clock.Put(key, value.Vector)
|
||||
m.Clock.put(key, value.Vector)
|
||||
}
|
||||
|
||||
for key, value := range snapshot.Remove {
|
||||
m.removeMap.put(key, value)
|
||||
m.Clock.Put(key, value.Vector)
|
||||
m.Clock.put(key, value.Vector)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||
err := enc.Encode(hash)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof(err.Error())
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
}
|
||||
|
||||
syncer.IncrementState()
|
||||
@ -59,7 +59,7 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||
err := dec.Decode(&hash)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof(err.Error())
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
}
|
||||
|
||||
// If vector clocks are equal then no need to merge state
|
||||
@ -74,7 +74,7 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||
err = enc.Encode(*syncer.mapState)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof(err.Error())
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
}
|
||||
|
||||
syncer.IncrementState()
|
||||
@ -93,10 +93,11 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||
err := dec.Decode(&mapState)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteInfof(err.Error())
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
}
|
||||
|
||||
difference := syncer.mapState.Difference(&mapState)
|
||||
syncer.manager.store.Clock.Merge(mapState.Vectors)
|
||||
|
||||
var sendBuffer bytes.Buffer
|
||||
enc := gob.NewEncoder(&sendBuffer)
|
||||
@ -163,7 +164,7 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
|
||||
|
||||
func (t *TwoPhaseSyncer) Complete() {
|
||||
logging.Log.WriteInfof("SYNC COMPLETED")
|
||||
if t.state == FINISHED || t.state == MERGE {
|
||||
if t.state >= MERGE {
|
||||
t.manager.store.Clock.IncrementClock()
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package crdt
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -19,7 +18,7 @@ type VectorBucket struct {
|
||||
// Vector clock defines an abstract data type
|
||||
// for a vector clock implementation
|
||||
type VectorClock[K cmp.Ordered] struct {
|
||||
vectors map[K]*VectorBucket
|
||||
vectors map[uint64]*VectorBucket
|
||||
lock sync.RWMutex
|
||||
processID K
|
||||
staleTime uint64
|
||||
@ -40,7 +39,7 @@ func (m *VectorClock[K]) IncrementClock() uint64 {
|
||||
lastUpdate: uint64(time.Now().Unix()),
|
||||
}
|
||||
|
||||
m.vectors[m.processID] = &newBucket
|
||||
m.vectors[m.hashFunc(m.processID)] = &newBucket
|
||||
|
||||
m.lock.Unlock()
|
||||
return maxClock
|
||||
@ -53,26 +52,28 @@ func (m *VectorClock[K]) GetHash() uint64 {
|
||||
|
||||
hash := uint64(0)
|
||||
|
||||
sortedKeys := lib.MapKeys(m.vectors)
|
||||
slices.Sort(sortedKeys)
|
||||
|
||||
for key, bucket := range m.vectors {
|
||||
hash += m.hashFunc(key)
|
||||
hash += bucket.clock
|
||||
hash += key * (bucket.clock + 1)
|
||||
}
|
||||
|
||||
m.lock.RUnlock()
|
||||
return hash
|
||||
}
|
||||
|
||||
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
|
||||
for key, value := range vectors {
|
||||
m.put(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// getStale: get all entries that are stale within the mesh
|
||||
func (m *VectorClock[K]) getStale() []K {
|
||||
func (m *VectorClock[K]) getStale() []uint64 {
|
||||
m.lock.RLock()
|
||||
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
|
||||
return max(i, vb.lastUpdate)
|
||||
})
|
||||
|
||||
toRemove := make([]K, 0)
|
||||
toRemove := make([]uint64, 0)
|
||||
|
||||
for key, bucket := range m.vectors {
|
||||
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
|
||||
@ -97,10 +98,14 @@ func (m *VectorClock[K]) Prune() {
|
||||
}
|
||||
|
||||
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
||||
return m.vectors[processId].lastUpdate
|
||||
return m.vectors[m.hashFunc(m.processID)].lastUpdate
|
||||
}
|
||||
|
||||
func (m *VectorClock[K]) Put(key K, value uint64) {
|
||||
m.put(m.hashFunc(key), value)
|
||||
}
|
||||
|
||||
func (m *VectorClock[K]) put(key uint64, value uint64) {
|
||||
clockValue := uint64(0)
|
||||
|
||||
m.lock.Lock()
|
||||
@ -121,16 +126,13 @@ func (m *VectorClock[K]) Put(key K, value uint64) {
|
||||
m.lock.Unlock()
|
||||
}
|
||||
|
||||
func (m *VectorClock[K]) GetClock() map[K]uint64 {
|
||||
clock := make(map[K]uint64)
|
||||
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
|
||||
clock := make(map[uint64]uint64)
|
||||
|
||||
m.lock.RLock()
|
||||
|
||||
keys := lib.MapKeys(m.vectors)
|
||||
slices.Sort(keys)
|
||||
|
||||
for key, value := range clock {
|
||||
clock[key] = value
|
||||
for key, value := range m.vectors {
|
||||
clock[key] = value.clock
|
||||
}
|
||||
|
||||
m.lock.RUnlock()
|
||||
@ -139,7 +141,7 @@ func (m *VectorClock[K]) GetClock() map[K]uint64 {
|
||||
|
||||
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
|
||||
return &VectorClock[K]{
|
||||
vectors: make(map[K]*VectorBucket),
|
||||
vectors: make(map[uint64]*VectorBucket),
|
||||
processID: processID,
|
||||
staleTime: staleTime,
|
||||
hashFunc: hashFunc,
|
||||
|
@ -248,6 +248,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
|
||||
if route.equal(r) {
|
||||
return false
|
||||
}
|
||||
|
||||
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@ -255,7 +263,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
|
||||
toDelete := Filter(ifRoutes, shouldExclude)
|
||||
|
||||
for _, route := range toDelete {
|
||||
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
|
||||
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
|
||||
err := c.DeleteRoute(ifName, route)
|
||||
|
||||
if err != nil {
|
||||
|
@ -116,7 +116,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
|
||||
ula := &ip.ULABuilder{}
|
||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||
|
||||
return ipNet
|
||||
})
|
||||
|
||||
@ -125,6 +124,13 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
||||
|
||||
for _, route := range node.GetRoutes() {
|
||||
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
|
||||
v6Default, _, _ := net.ParseCIDR("::/0")
|
||||
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
|
||||
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute {
|
||||
return true
|
||||
}
|
||||
|
||||
return prefix.Contains(route.GetDestination().IP)
|
||||
}) {
|
||||
continue
|
||||
@ -166,8 +172,16 @@ func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client Mesh
|
||||
return peer
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) {
|
||||
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
|
||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
||||
ula := &ip.ULABuilder{}
|
||||
meshNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||
|
||||
routes := lib.Map(lib.MapKeys(m.getRoutes(mesh)), func(destination string) net.IPNet {
|
||||
_, ipNet, _ := net.ParseCIDR(destination)
|
||||
return *ipNet
|
||||
})
|
||||
routes = append(routes, *meshNet)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -184,23 +198,29 @@ func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNod
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowedips := make([]net.IPNet, 1)
|
||||
_, ipnet, _ := net.ParseCIDR("::/0")
|
||||
allowedips[0] = *ipnet
|
||||
|
||||
peerCfgs := make([]wgtypes.PeerConfig, 1)
|
||||
|
||||
peerCfgs[0] = wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Endpoint: endpoint,
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
AllowedIPs: allowedips,
|
||||
AllowedIPs: routes,
|
||||
}
|
||||
|
||||
installedRoutes := make([]lib.Route, 0)
|
||||
|
||||
for _, route := range peerCfgs[0].AllowedIPs {
|
||||
installedRoutes = append(installedRoutes, lib.Route{
|
||||
Gateway: peer.GetWgHost().IP,
|
||||
Destination: route,
|
||||
})
|
||||
}
|
||||
|
||||
cfg := wgtypes.Config{
|
||||
Peers: peerCfgs,
|
||||
}
|
||||
|
||||
m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
||||
return &cfg, err
|
||||
}
|
||||
|
||||
@ -255,7 +275,9 @@ func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode,
|
||||
ula := &ip.ULABuilder{}
|
||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||
|
||||
if !ipNet.Contains(route.IP) {
|
||||
_, defaultRoute, _ := net.ParseCIDR("::/0")
|
||||
|
||||
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
|
||||
installedRoutes = append(installedRoutes, lib.Route{
|
||||
Gateway: n.GetWgHost().IP,
|
||||
Destination: route,
|
||||
@ -309,7 +331,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
case conf.PEER_ROLE:
|
||||
cfg, err = m.getPeerConfig(mesh, peers, clients, dev)
|
||||
case conf.CLIENT_ROLE:
|
||||
cfg, err = m.getClientConfig(mesh, peers, clients)
|
||||
cfg, err = m.getClientConfig(mesh, peers, clients, dev)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -471,7 +471,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
|
||||
m.RouteManager = params.RouteManager
|
||||
|
||||
if m.RouteManager == nil {
|
||||
m.RouteManager = NewRouteManager(m)
|
||||
m.RouteManager = NewRouteManager(m, ¶ms.Conf)
|
||||
}
|
||||
|
||||
m.idGenerator = params.IdGenerator
|
||||
|
@ -1,6 +1,9 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
@ -13,6 +16,7 @@ type RouteManager interface {
|
||||
|
||||
type RouteManagerImpl struct {
|
||||
meshManager MeshManager
|
||||
conf *conf.WgMeshConfiguration
|
||||
}
|
||||
|
||||
func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
@ -32,12 +36,23 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
return err
|
||||
}
|
||||
|
||||
routes, err := mesh1.GetRoutes(pubKey.String())
|
||||
routeMap, err := mesh1.GetRoutes(pubKey.String())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.conf.AdvertiseDefaultRoute {
|
||||
_, ipv6Default, _ := net.ParseCIDR("::/0")
|
||||
|
||||
mesh1.AddRoutes(NodeID(self),
|
||||
&RouteStub{
|
||||
Destination: ipv6Default,
|
||||
HopCount: 0,
|
||||
Path: make([]string, 0),
|
||||
})
|
||||
}
|
||||
|
||||
for _, mesh2 := range meshes {
|
||||
if mesh1 == mesh2 {
|
||||
continue
|
||||
@ -50,7 +65,9 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
|
||||
routes := lib.MapValues(routeMap)
|
||||
|
||||
err = mesh2.AddRoutes(NodeID(self), append(routes, &RouteStub{
|
||||
Destination: ipNet,
|
||||
HopCount: 0,
|
||||
Path: make([]string, 0),
|
||||
@ -88,6 +105,6 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRouteManager(m MeshManager) RouteManager {
|
||||
return &RouteManagerImpl{meshManager: m}
|
||||
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
|
||||
return &RouteManagerImpl{meshManager: m, conf: conf}
|
||||
}
|
||||
|
@ -19,7 +19,11 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
|
||||
return err
|
||||
}
|
||||
|
||||
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
|
||||
ip6Routes := lib.Filter(routes, func(r lib.Route) bool {
|
||||
return r.Destination.IP.To4() == nil
|
||||
})
|
||||
|
||||
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, ip6Routes...)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -123,7 +123,7 @@ func (s *SyncerImpl) SyncMeshes() error {
|
||||
err := s.Sync(meshId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,7 +8,8 @@ import (
|
||||
// Run implements SyncScheduler.
|
||||
func syncFunction(syncer Syncer) lib.TimerFunc {
|
||||
return func() error {
|
||||
return syncer.SyncMeshes()
|
||||
syncer.SyncMeshes()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user