forked from extern/smegmesh
Compare commits
7 Commits
45-use-sta
...
51-bufix-n
Author | SHA1 | Date | |
---|---|---|---|
1e263cc6a8 | |||
dae9cd31a1 | |||
f855f53fbf | |||
52feb5767b | |||
815c4484ee | |||
0058c9f4c9 | |||
92c0805275 |
@ -449,7 +449,7 @@ func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes deletes the specified routes
|
// DeleteRoutes deletes the specified routes
|
||||||
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -467,7 +467,7 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
err = routeMap.Map().Delete(route)
|
err = routeMap.Map().Delete(route.GetDestination().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
@ -47,6 +47,8 @@ type WgMeshConfiguration struct {
|
|||||||
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
|
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
|
||||||
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
|
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
|
||||||
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
|
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
|
||||||
|
// AdvertiseDefaultRoute advertises a default route out of the mesh.
|
||||||
|
AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"`
|
||||||
// Endpoint is the IP in which this computer is publicly reachable.
|
// Endpoint is the IP in which this computer is publicly reachable.
|
||||||
// usecase is when the node has multiple IP addresses
|
// usecase is when the node has multiple IP addresses
|
||||||
Endpoint string `yaml:"publicEndpoint"`
|
Endpoint string `yaml:"publicEndpoint"`
|
||||||
|
@ -179,8 +179,16 @@ func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
|
|||||||
|
|
||||||
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
|
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
|
||||||
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
||||||
|
nodes := m.store.AsList()
|
||||||
|
|
||||||
|
snapshot := make(map[string]MeshNode)
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
snapshot[node.PublicKey] = node
|
||||||
|
}
|
||||||
|
|
||||||
return &MeshSnapshot{
|
return &MeshSnapshot{
|
||||||
Nodes: m.store.AsMap(),
|
Nodes: snapshot,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,7 +320,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes: deletes the routes from the node
|
// DeleteRoutes: deletes the routes from the node
|
||||||
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
if !m.store.Contains(nodeId) {
|
if !m.store.Contains(nodeId) {
|
||||||
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
}
|
}
|
||||||
@ -323,8 +331,15 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string)
|
|||||||
|
|
||||||
node := m.store.Get(nodeId)
|
node := m.store.Get(nodeId)
|
||||||
|
|
||||||
|
changes := false
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
delete(node.Routes, route)
|
changes = true
|
||||||
|
delete(node.Routes, route.GetDestination().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if changes {
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -408,7 +423,7 @@ func (m *TwoPhaseStoreMeshManager) Prune() error {
|
|||||||
|
|
||||||
// GetPeers: get a list of contactable peers
|
// GetPeers: get a list of contactable peers
|
||||||
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
||||||
nodes := lib.MapValues(m.store.AsMap())
|
nodes := m.store.AsList()
|
||||||
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
||||||
if mn.Type != string(conf.PEER_ROLE) {
|
if mn.Type != string(conf.PEER_ROLE) {
|
||||||
return false
|
return false
|
||||||
|
@ -18,9 +18,9 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
|
|||||||
Client: params.Client,
|
Client: params.Client,
|
||||||
conf: params.Conf,
|
conf: params.Conf,
|
||||||
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
|
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
|
||||||
h := fnv.New32a()
|
h := fnv.New64a()
|
||||||
h.Write([]byte(s))
|
h.Write([]byte(s))
|
||||||
return uint64(h.Sum32())
|
return h.Sum64()
|
||||||
}, uint64(3*params.Conf.KeepAliveTime)),
|
}, uint64(3*params.Conf.KeepAliveTime)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ type Bucket[D any] struct {
|
|||||||
// GMap is a set that can only grow in size
|
// GMap is a set that can only grow in size
|
||||||
type GMap[K cmp.Ordered, D any] struct {
|
type GMap[K cmp.Ordered, D any] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
contents map[K]Bucket[D]
|
contents map[uint64]Bucket[D]
|
||||||
clock *VectorClock[K]
|
clock *VectorClock[K]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ func (g *GMap[K, D]) Put(key K, value D) {
|
|||||||
|
|
||||||
clock := g.clock.IncrementClock()
|
clock := g.clock.IncrementClock()
|
||||||
|
|
||||||
g.contents[key] = Bucket[D]{
|
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
|
||||||
Vector: clock,
|
Vector: clock,
|
||||||
Contents: value,
|
Contents: value,
|
||||||
}
|
}
|
||||||
@ -33,6 +33,10 @@ func (g *GMap[K, D]) Put(key K, value D) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Contains(key K) bool {
|
func (g *GMap[K, D]) Contains(key K) bool {
|
||||||
|
return g.contains(g.clock.hashFunc(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) contains(key uint64) bool {
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
_, ok := g.contents[key]
|
_, ok := g.contents[key]
|
||||||
@ -42,7 +46,7 @@ func (g *GMap[K, D]) Contains(key K) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
|
||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
|
|
||||||
if g.contents[key].Vector < b.Vector {
|
if g.contents[key].Vector < b.Vector {
|
||||||
@ -52,7 +56,7 @@ func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
|||||||
g.lock.Unlock()
|
g.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) get(key K) Bucket[D] {
|
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
bucket := g.contents[key]
|
bucket := g.contents[key]
|
||||||
g.lock.RUnlock()
|
g.lock.RUnlock()
|
||||||
@ -61,14 +65,14 @@ func (g *GMap[K, D]) get(key K) Bucket[D] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Get(key K) D {
|
func (g *GMap[K, D]) Get(key K) D {
|
||||||
return g.get(key).Contents
|
return g.get(g.clock.hashFunc(key)).Contents
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Mark(key K) {
|
func (g *GMap[K, D]) Mark(key K) {
|
||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
bucket := g.contents[key]
|
bucket := g.contents[g.clock.hashFunc(key)]
|
||||||
bucket.Gravestone = true
|
bucket.Gravestone = true
|
||||||
g.contents[key] = bucket
|
g.contents[g.clock.hashFunc(key)] = bucket
|
||||||
g.lock.Unlock()
|
g.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +82,7 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
|
|||||||
|
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
bucket, ok := g.contents[key]
|
bucket, ok := g.contents[g.clock.hashFunc(key)]
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
marked = bucket.Gravestone
|
marked = bucket.Gravestone
|
||||||
@ -89,10 +93,10 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
|
|||||||
return marked
|
return marked
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Keys() []K {
|
func (g *GMap[K, D]) Keys() []uint64 {
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
contents := make([]K, len(g.contents))
|
contents := make([]uint64, len(g.contents))
|
||||||
index := 0
|
index := 0
|
||||||
|
|
||||||
for key := range g.contents {
|
for key := range g.contents {
|
||||||
@ -104,8 +108,8 @@ func (g *GMap[K, D]) Keys() []K {
|
|||||||
return contents
|
return contents
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
|
||||||
buckets := make(map[K]Bucket[D])
|
buckets := make(map[uint64]Bucket[D])
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
for key, value := range g.contents {
|
for key, value := range g.contents {
|
||||||
@ -116,8 +120,8 @@ func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
|||||||
return buckets
|
return buckets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
|
||||||
buckets := make(map[K]Bucket[D])
|
buckets := make(map[uint64]Bucket[D])
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
@ -128,8 +132,8 @@ func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
|||||||
return buckets
|
return buckets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) GetClock() map[K]uint64 {
|
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
|
||||||
clock := make(map[K]uint64)
|
clock := make(map[uint64]uint64)
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
for key, bucket := range g.contents {
|
for key, bucket := range g.contents {
|
||||||
@ -166,7 +170,7 @@ func (g *GMap[K, D]) Prune() {
|
|||||||
|
|
||||||
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
|
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
|
||||||
return &GMap[K, D]{
|
return &GMap[K, D]{
|
||||||
contents: make(map[K]Bucket[D]),
|
contents: make(map[uint64]Bucket[D]),
|
||||||
clock: clock,
|
clock: clock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,19 +14,24 @@ type TwoPhaseMap[K cmp.Ordered, D any] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
|
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
|
||||||
Add map[K]Bucket[D]
|
Add map[uint64]Bucket[D]
|
||||||
Remove map[K]Bucket[bool]
|
Remove map[uint64]Bucket[bool]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains checks whether the value exists in the map
|
// Contains checks whether the value exists in the map
|
||||||
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
|
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
|
||||||
if !m.addMap.Contains(key) {
|
return m.contains(m.Clock.hashFunc(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks whether the value exists in the map
|
||||||
|
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
|
||||||
|
if !m.addMap.contains(key) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
addValue := m.addMap.get(key)
|
addValue := m.addMap.get(key)
|
||||||
|
|
||||||
if !m.removeMap.Contains(key) {
|
if !m.removeMap.contains(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,6 +50,16 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
|
|||||||
return m.addMap.Get(key)
|
return m.addMap.Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
|
||||||
|
var result D
|
||||||
|
|
||||||
|
if !m.contains(key) {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.addMap.get(key).Contents
|
||||||
|
}
|
||||||
|
|
||||||
// Put places the key K in the map
|
// Put places the key K in the map
|
||||||
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
||||||
msgSequence := m.Clock.IncrementClock()
|
msgSequence := m.Clock.IncrementClock()
|
||||||
@ -61,13 +76,13 @@ func (m *TwoPhaseMap[K, D]) Remove(key K) {
|
|||||||
m.removeMap.Put(key, true)
|
m.removeMap.Put(key, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) Keys() []K {
|
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
|
||||||
keys := make([]K, 0)
|
keys := make([]uint64, 0)
|
||||||
|
|
||||||
addKeys := m.addMap.Keys()
|
addKeys := m.addMap.Keys()
|
||||||
|
|
||||||
for _, key := range addKeys {
|
for _, key := range addKeys {
|
||||||
if !m.Contains(key) {
|
if !m.contains(key) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,16 +92,16 @@ func (m *TwoPhaseMap[K, D]) Keys() []K {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) AsMap() map[K]D {
|
func (m *TwoPhaseMap[K, D]) AsList() []D {
|
||||||
theMap := make(map[K]D)
|
theList := make([]D, 0)
|
||||||
|
|
||||||
keys := m.Keys()
|
keys := m.keys()
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
theMap[key] = m.Get(key)
|
theList = append(theList, m.get(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
return theMap
|
return theList
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
|
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
|
||||||
@ -107,9 +122,9 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseMapState[K cmp.Ordered] struct {
|
type TwoPhaseMapState[K cmp.Ordered] struct {
|
||||||
Vectors map[K]uint64
|
Vectors map[uint64]uint64
|
||||||
AddContents map[K]uint64
|
AddContents map[uint64]uint64
|
||||||
RemoveContents map[K]uint64
|
RemoveContents map[uint64]uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
||||||
@ -120,7 +135,7 @@ func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
|||||||
// Sums the current values of the vectors. Provides good approximation
|
// Sums the current values of the vectors. Provides good approximation
|
||||||
// of increasing numbers
|
// of increasing numbers
|
||||||
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
||||||
return m.addMap.GetHash() + m.removeMap.GetHash()
|
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetState: get the current vector clock of the add and remove
|
// GetState: get the current vector clock of the add and remove
|
||||||
@ -136,16 +151,10 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) UpdateVector(state *TwoPhaseMapState[K]) {
|
|
||||||
for key, value := range state.Vectors {
|
|
||||||
m.Clock.Put(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
||||||
mapState := &TwoPhaseMapState[K]{
|
mapState := &TwoPhaseMapState[K]{
|
||||||
AddContents: make(map[K]uint64),
|
AddContents: make(map[uint64]uint64),
|
||||||
RemoveContents: make(map[K]uint64),
|
RemoveContents: make(map[uint64]uint64),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range state.AddContents {
|
for key, value := range state.AddContents {
|
||||||
@ -172,12 +181,12 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
|||||||
// Gravestone is local only to that node.
|
// Gravestone is local only to that node.
|
||||||
// Discover ourselves if the node is alive
|
// Discover ourselves if the node is alive
|
||||||
m.addMap.put(key, value)
|
m.addMap.put(key, value)
|
||||||
m.Clock.Put(key, value.Vector)
|
m.Clock.put(key, value.Vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range snapshot.Remove {
|
for key, value := range snapshot.Remove {
|
||||||
m.removeMap.put(key, value)
|
m.removeMap.put(key, value)
|
||||||
m.Clock.Put(key, value.Vector)
|
m.Clock.put(key, value.Vector)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
err := enc.Encode(hash)
|
err := enc.Encode(hash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
syncer.IncrementState()
|
syncer.IncrementState()
|
||||||
@ -59,7 +59,7 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
err := dec.Decode(&hash)
|
err := dec.Decode(&hash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// If vector clocks are equal then no need to merge state
|
// If vector clocks are equal then no need to merge state
|
||||||
@ -74,7 +74,7 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
err = enc.Encode(*syncer.mapState)
|
err = enc.Encode(*syncer.mapState)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
syncer.IncrementState()
|
syncer.IncrementState()
|
||||||
@ -93,10 +93,11 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
err := dec.Decode(&mapState)
|
err := dec.Decode(&mapState)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
difference := syncer.mapState.Difference(&mapState)
|
difference := syncer.mapState.Difference(&mapState)
|
||||||
|
syncer.manager.store.Clock.Merge(mapState.Vectors)
|
||||||
|
|
||||||
var sendBuffer bytes.Buffer
|
var sendBuffer bytes.Buffer
|
||||||
enc := gob.NewEncoder(&sendBuffer)
|
enc := gob.NewEncoder(&sendBuffer)
|
||||||
@ -163,7 +164,7 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
|
|||||||
|
|
||||||
func (t *TwoPhaseSyncer) Complete() {
|
func (t *TwoPhaseSyncer) Complete() {
|
||||||
logging.Log.WriteInfof("SYNC COMPLETED")
|
logging.Log.WriteInfof("SYNC COMPLETED")
|
||||||
if t.state == FINISHED || t.state == MERGE {
|
if t.state >= MERGE {
|
||||||
t.manager.store.Clock.IncrementClock()
|
t.manager.store.Clock.IncrementClock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package crdt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
"slices"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -19,7 +18,7 @@ type VectorBucket struct {
|
|||||||
// Vector clock defines an abstract data type
|
// Vector clock defines an abstract data type
|
||||||
// for a vector clock implementation
|
// for a vector clock implementation
|
||||||
type VectorClock[K cmp.Ordered] struct {
|
type VectorClock[K cmp.Ordered] struct {
|
||||||
vectors map[K]*VectorBucket
|
vectors map[uint64]*VectorBucket
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
processID K
|
processID K
|
||||||
staleTime uint64
|
staleTime uint64
|
||||||
@ -40,7 +39,7 @@ func (m *VectorClock[K]) IncrementClock() uint64 {
|
|||||||
lastUpdate: uint64(time.Now().Unix()),
|
lastUpdate: uint64(time.Now().Unix()),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.vectors[m.processID] = &newBucket
|
m.vectors[m.hashFunc(m.processID)] = &newBucket
|
||||||
|
|
||||||
m.lock.Unlock()
|
m.lock.Unlock()
|
||||||
return maxClock
|
return maxClock
|
||||||
@ -53,26 +52,28 @@ func (m *VectorClock[K]) GetHash() uint64 {
|
|||||||
|
|
||||||
hash := uint64(0)
|
hash := uint64(0)
|
||||||
|
|
||||||
sortedKeys := lib.MapKeys(m.vectors)
|
|
||||||
slices.Sort(sortedKeys)
|
|
||||||
|
|
||||||
for key, bucket := range m.vectors {
|
for key, bucket := range m.vectors {
|
||||||
hash += m.hashFunc(key)
|
hash += key * (bucket.clock + 1)
|
||||||
hash += bucket.clock
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.RUnlock()
|
m.lock.RUnlock()
|
||||||
return hash
|
return hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
|
||||||
|
for key, value := range vectors {
|
||||||
|
m.put(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getStale: get all entries that are stale within the mesh
|
// getStale: get all entries that are stale within the mesh
|
||||||
func (m *VectorClock[K]) getStale() []K {
|
func (m *VectorClock[K]) getStale() []uint64 {
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
|
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
|
||||||
return max(i, vb.lastUpdate)
|
return max(i, vb.lastUpdate)
|
||||||
})
|
})
|
||||||
|
|
||||||
toRemove := make([]K, 0)
|
toRemove := make([]uint64, 0)
|
||||||
|
|
||||||
for key, bucket := range m.vectors {
|
for key, bucket := range m.vectors {
|
||||||
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
|
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
|
||||||
@ -97,10 +98,19 @@ func (m *VectorClock[K]) Prune() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
||||||
return m.vectors[processId].lastUpdate
|
m.lock.RLock()
|
||||||
|
|
||||||
|
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return lastUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *VectorClock[K]) Put(key K, value uint64) {
|
func (m *VectorClock[K]) Put(key K, value uint64) {
|
||||||
|
m.put(m.hashFunc(key), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) put(key uint64, value uint64) {
|
||||||
clockValue := uint64(0)
|
clockValue := uint64(0)
|
||||||
|
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
@ -121,16 +131,13 @@ func (m *VectorClock[K]) Put(key K, value uint64) {
|
|||||||
m.lock.Unlock()
|
m.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *VectorClock[K]) GetClock() map[K]uint64 {
|
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
|
||||||
clock := make(map[K]uint64)
|
clock := make(map[uint64]uint64)
|
||||||
|
|
||||||
m.lock.RLock()
|
m.lock.RLock()
|
||||||
|
|
||||||
keys := lib.MapKeys(m.vectors)
|
for key, value := range m.vectors {
|
||||||
slices.Sort(keys)
|
clock[key] = value.clock
|
||||||
|
|
||||||
for key, value := range clock {
|
|
||||||
clock[key] = value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.RUnlock()
|
m.lock.RUnlock()
|
||||||
@ -139,7 +146,7 @@ func (m *VectorClock[K]) GetClock() map[K]uint64 {
|
|||||||
|
|
||||||
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
|
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
|
||||||
return &VectorClock[K]{
|
return &VectorClock[K]{
|
||||||
vectors: make(map[K]*VectorBucket),
|
vectors: make(map[uint64]*VectorBucket),
|
||||||
processID: processID,
|
processID: processID,
|
||||||
staleTime: staleTime,
|
staleTime: staleTime,
|
||||||
hashFunc: hashFunc,
|
hashFunc: hashFunc,
|
||||||
|
@ -7,6 +7,27 @@ func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
|
|||||||
return MapValuesWithExclude(m, map[K]struct{}{})
|
return MapValuesWithExclude(m, map[K]struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MapItemsEntry[K cmp.Ordered, V any] struct {
|
||||||
|
Key K
|
||||||
|
Value V
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
|
||||||
|
keys := MapKeys(m)
|
||||||
|
values := MapValues(m)
|
||||||
|
|
||||||
|
vs := make([]MapItemsEntry[K, V], len(keys))
|
||||||
|
|
||||||
|
for index, _ := range keys {
|
||||||
|
vs[index] = MapItemsEntry[K, V]{
|
||||||
|
Key: keys[index],
|
||||||
|
Value: values[index],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
|
||||||
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
|
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
|
||||||
values := make([]V, len(m)-len(exclude))
|
values := make([]V, len(m)-len(exclude))
|
||||||
|
|
||||||
|
@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
|
|||||||
family = unix.AF_INET
|
family = unix.AF_INET
|
||||||
}
|
}
|
||||||
|
|
||||||
attr := rtnetlink.RouteAttributes{
|
routes, err := c.listRoutes(ifName, family)
|
||||||
Dst: dst.IP,
|
|
||||||
OutIface: uint32(iface.Index),
|
|
||||||
Gateway: gw,
|
|
||||||
}
|
|
||||||
|
|
||||||
ones, _ := dst.Mask.Size()
|
|
||||||
|
|
||||||
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
|
|
||||||
Family: family,
|
|
||||||
Table: unix.RT_TABLE_MAIN,
|
|
||||||
Protocol: unix.RTPROT_BOOT,
|
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
|
||||||
Type: unix.RTN_UNICAST,
|
|
||||||
DstLength: uint8(ones),
|
|
||||||
Attributes: attr,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add route %w", err)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it already exists no need to add the route
|
||||||
|
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
|
||||||
|
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
|
||||||
|
prevRoute.Attributes.Gateway.Equal(route.Gateway)
|
||||||
|
}) {
|
||||||
|
attr := rtnetlink.RouteAttributes{
|
||||||
|
Dst: dst.IP,
|
||||||
|
OutIface: uint32(iface.Index),
|
||||||
|
Gateway: gw,
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := dst.Mask.Size()
|
||||||
|
|
||||||
|
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
|
||||||
|
Family: family,
|
||||||
|
Table: unix.RT_TABLE_MAIN,
|
||||||
|
Protocol: unix.RTPROT_BOOT,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
DstLength: uint8(ones),
|
||||||
|
Attributes: attr,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add route %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -248,6 +260,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
|
|||||||
if route.equal(r) {
|
if route.equal(r) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -255,7 +275,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
|
|||||||
toDelete := Filter(ifRoutes, shouldExclude)
|
toDelete := Filter(ifRoutes, shouldExclude)
|
||||||
|
|
||||||
for _, route := range toDelete {
|
for _, route := range toDelete {
|
||||||
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
|
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
|
||||||
err := c.DeleteRoute(ifName, route)
|
err := c.DeleteRoute(ifName, route)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/route"
|
"github.com/tim-beatham/wgmesh/pkg/route"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@ -27,6 +26,7 @@ type WgMeshConfigApplyer struct {
|
|||||||
meshManager MeshManager
|
meshManager MeshManager
|
||||||
config *conf.WgMeshConfiguration
|
config *conf.WgMeshConfiguration
|
||||||
routeInstaller route.RouteInstaller
|
routeInstaller route.RouteInstaller
|
||||||
|
hashFunc func(MeshNode) int
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeNode struct {
|
type routeNode struct {
|
||||||
@ -34,16 +34,11 @@ type routeNode struct {
|
|||||||
route Route
|
route Route
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
|
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
|
||||||
|
device *wgtypes.Device,
|
||||||
peerToClients map[string][]net.IPNet,
|
peerToClients map[string][]net.IPNet,
|
||||||
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
|
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
|
||||||
|
|
||||||
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, err := node.GetPublicKey()
|
pubKey, err := node.GetPublicKey()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -66,17 +61,12 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
|||||||
if len(bestRoutes) == 1 {
|
if len(bestRoutes) == 1 {
|
||||||
pickedRoute = bestRoutes[0]
|
pickedRoute = bestRoutes[0]
|
||||||
} else if len(bestRoutes) > 1 {
|
} else if len(bestRoutes) > 1 {
|
||||||
keyFunc := func(mn MeshNode) int {
|
|
||||||
pubKey, _ := mn.GetPublicKey()
|
|
||||||
return lib.HashString(pubKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
bucketFunc := func(rn routeNode) int {
|
bucketFunc := func(rn routeNode) int {
|
||||||
return lib.HashString(rn.gateway)
|
return lib.HashString(rn.gateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Else there is more than one candidate so consistently hash
|
// Else there is more than one candidate so consistently hash
|
||||||
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
|
pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pickedRoute.gateway == pubKey.String() {
|
if pickedRoute.gateway == pubKey.String() {
|
||||||
@ -91,6 +81,13 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
|||||||
return p.PublicKey.String() == pubKey.String()
|
return p.PublicKey.String() == pubKey.String()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't override the existing IP in case it already exists
|
||||||
if existing != -1 {
|
if existing != -1 {
|
||||||
endpoint = device.Peers[existing].Endpoint
|
endpoint = device.Peers[existing].Endpoint
|
||||||
}
|
}
|
||||||
@ -110,13 +107,15 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
|||||||
// consistently hash to evenly spread the distribution of traffic
|
// consistently hash to evenly spread the distribution of traffic
|
||||||
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
|
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
|
||||||
mesh, _ := meshProvider.GetMesh()
|
mesh, _ := meshProvider.GetMesh()
|
||||||
|
|
||||||
routes := make(map[string][]routeNode)
|
routes := make(map[string][]routeNode)
|
||||||
|
|
||||||
|
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
|
||||||
|
return p.GetType() == conf.PEER_ROLE
|
||||||
|
})
|
||||||
|
|
||||||
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
|
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
|
||||||
ula := &ip.ULABuilder{}
|
ula := &ip.ULABuilder{}
|
||||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||||
|
|
||||||
return ipNet
|
return ipNet
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -125,6 +124,13 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
|
|
||||||
for _, route := range node.GetRoutes() {
|
for _, route := range node.GetRoutes() {
|
||||||
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
|
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
|
||||||
|
v6Default, _, _ := net.ParseCIDR("::/0")
|
||||||
|
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
|
|
||||||
|
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return prefix.Contains(route.GetDestination().IP)
|
return prefix.Contains(route.GetDestination().IP)
|
||||||
}) {
|
}) {
|
||||||
continue
|
continue
|
||||||
@ -138,6 +144,24 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
route: route,
|
route: route,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Client's only acessible by another peer
|
||||||
|
if node.GetType() == conf.CLIENT_ROLE {
|
||||||
|
peer := m.getCorrespondingPeer(peers, node)
|
||||||
|
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
|
||||||
|
|
||||||
|
// If the node isn't the self use that peer as the gateway
|
||||||
|
if !NodeEquals(peer, self) {
|
||||||
|
peerPub, _ := peer.GetPublicKey()
|
||||||
|
rn.gateway = peerPub.String()
|
||||||
|
rn.route = &RouteStub{
|
||||||
|
Destination: rn.route.GetDestination(),
|
||||||
|
HopCount: rn.route.GetHopCount() + 1,
|
||||||
|
// Append the path to this peer
|
||||||
|
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
otherRoute = make([]routeNode, 1)
|
otherRoute = make([]routeNode, 1)
|
||||||
otherRoute[0] = rn
|
otherRoute[0] = rn
|
||||||
@ -145,8 +169,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
||||||
otherRoute[0] = rn
|
otherRoute[0] = rn
|
||||||
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
||||||
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
|
|
||||||
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
|
|
||||||
routes[destination] = append(otherRoute, rn)
|
routes[destination] = append(otherRoute, rn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -157,26 +179,57 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
|
|
||||||
// getCorrespondignPeer: gets the peer corresponding to the client
|
// getCorrespondignPeer: gets the peer corresponding to the client
|
||||||
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
|
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
|
||||||
hashFunc := func(mn MeshNode) int {
|
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
|
||||||
pubKey, _ := mn.GetPublicKey()
|
|
||||||
return lib.HashString(pubKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := lib.ConsistentHash(peers, client, hashFunc, hashFunc)
|
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) {
|
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
|
||||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
peers := dev.Peers
|
||||||
|
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
|
||||||
|
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
|
||||||
|
return p1.PublicKey.String() == p2.PublicKey.String()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
|
||||||
|
return wgtypes.PeerConfig{
|
||||||
|
PublicKey: p.PublicKey,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetConfigParams struct {
|
||||||
|
mesh MeshProvider
|
||||||
|
peers []MeshNode
|
||||||
|
clients []MeshNode
|
||||||
|
dev *wgtypes.Device
|
||||||
|
routes map[string][]routeNode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
|
||||||
|
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
|
||||||
|
ula := &ip.ULABuilder{}
|
||||||
|
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
|
||||||
|
|
||||||
|
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
|
||||||
|
return lib.Filter(rns, func(rn routeNode) bool {
|
||||||
|
ip, _, _ := net.ParseCIDR(rn.gateway)
|
||||||
|
return meshNet.Contains(ip)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
|
||||||
|
return *rs[0].route.GetDestination()
|
||||||
|
})
|
||||||
|
routes = append(routes, *meshNet)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := m.getCorrespondingPeer(peers, self)
|
peer := m.getCorrespondingPeer(params.peers, self)
|
||||||
|
|
||||||
pubKey, _ := peer.GetPublicKey()
|
pubKey, _ := peer.GetPublicKey()
|
||||||
|
|
||||||
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
||||||
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
||||||
|
|
||||||
@ -184,40 +237,66 @@ func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNod
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
allowedips := make([]net.IPNet, 1)
|
|
||||||
_, ipnet, _ := net.ParseCIDR("::/0")
|
|
||||||
allowedips[0] = *ipnet
|
|
||||||
|
|
||||||
peerCfgs := make([]wgtypes.PeerConfig, 1)
|
peerCfgs := make([]wgtypes.PeerConfig, 1)
|
||||||
|
|
||||||
peerCfgs[0] = wgtypes.PeerConfig{
|
peerCfgs[0] = wgtypes.PeerConfig{
|
||||||
PublicKey: pubKey,
|
PublicKey: pubKey,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
AllowedIPs: allowedips,
|
AllowedIPs: routes,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
installedRoutes := make([]lib.Route, 0)
|
||||||
|
|
||||||
|
for _, route := range peerCfgs[0].AllowedIPs {
|
||||||
|
installedRoutes = append(installedRoutes, lib.Route{
|
||||||
|
Gateway: peer.GetWgHost().IP,
|
||||||
|
Destination: route,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := wgtypes.Config{
|
cfg := wgtypes.Config{
|
||||||
Peers: peerCfgs,
|
Peers: peerCfgs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
|
||||||
return &cfg, err
|
return &cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
|
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
|
||||||
|
routes := make([]lib.Route, 0)
|
||||||
|
|
||||||
|
for _, route := range wgNode.AllowedIPs {
|
||||||
|
ula := &ip.ULABuilder{}
|
||||||
|
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||||
|
|
||||||
|
_, defaultRoute, _ := net.ParseCIDR("::/0")
|
||||||
|
|
||||||
|
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
|
||||||
|
routes = append(routes, lib.Route{
|
||||||
|
Gateway: node.GetWgHost().IP,
|
||||||
|
Destination: route,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
|
||||||
peerToClients := make(map[string][]net.IPNet)
|
peerToClients := make(map[string][]net.IPNet)
|
||||||
routes := m.getRoutes(mesh)
|
|
||||||
installedRoutes := make([]lib.Route, 0)
|
installedRoutes := make([]lib.Route, 0)
|
||||||
peerConfigs := make([]wgtypes.PeerConfig, 0)
|
peerConfigs := make([]wgtypes.PeerConfig, 0)
|
||||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, n := range clients {
|
for _, n := range params.clients {
|
||||||
if len(peers) > 0 {
|
if len(params.peers) > 0 {
|
||||||
peer := m.getCorrespondingPeer(peers, n)
|
peer := m.getCorrespondingPeer(params.peers, n)
|
||||||
pubKey, _ := peer.GetPublicKey()
|
pubKey, _ := peer.GetPublicKey()
|
||||||
clients, ok := peerToClients[pubKey.String()]
|
clients, ok := peerToClients[pubKey.String()]
|
||||||
|
|
||||||
@ -229,53 +308,42 @@ func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode,
|
|||||||
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
||||||
|
|
||||||
if NodeEquals(self, peer) {
|
if NodeEquals(self, peer) {
|
||||||
cfg, err := m.convertMeshNode(n, dev, peerToClients, routes)
|
cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
|
||||||
peerConfigs = append(peerConfigs, *cfg)
|
peerConfigs = append(peerConfigs, *cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, n := range peers {
|
for _, n := range params.peers {
|
||||||
if NodeEquals(n, self) {
|
if NodeEquals(n, self) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
|
peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range peer.AllowedIPs {
|
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
|
||||||
ula := &ip.ULABuilder{}
|
|
||||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
|
||||||
|
|
||||||
if !ipNet.Contains(route.IP) {
|
|
||||||
installedRoutes = append(installedRoutes, lib.Route{
|
|
||||||
Gateway: n.GetWgHost().IP,
|
|
||||||
Destination: route,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peerConfigs = append(peerConfigs, *peer)
|
peerConfigs = append(peerConfigs, *peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := wgtypes.Config{
|
cfg := wgtypes.Config{
|
||||||
Peers: peerConfigs,
|
Peers: peerConfigs,
|
||||||
ReplacePeers: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
|
||||||
return &cfg, err
|
return &cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
|
||||||
snap, err := mesh.GetMesh()
|
snap, err := mesh.GetMesh()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -305,17 +373,28 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
|
|
||||||
var cfg *wgtypes.Config = nil
|
var cfg *wgtypes.Config = nil
|
||||||
|
|
||||||
|
configParams := &GetConfigParams{
|
||||||
|
mesh: mesh,
|
||||||
|
peers: peers,
|
||||||
|
clients: clients,
|
||||||
|
dev: dev,
|
||||||
|
routes: routes,
|
||||||
|
}
|
||||||
|
|
||||||
switch self.GetType() {
|
switch self.GetType() {
|
||||||
case conf.PEER_ROLE:
|
case conf.PEER_ROLE:
|
||||||
cfg, err = m.getPeerConfig(mesh, peers, clients, dev)
|
cfg, err = m.getPeerConfig(configParams)
|
||||||
case conf.CLIENT_ROLE:
|
case conf.CLIENT_ROLE:
|
||||||
cfg, err = m.getClientConfig(mesh, peers, clients)
|
cfg, err = m.getClientConfig(configParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
|
||||||
|
cfg.Peers = append(cfg.Peers, toRemove...)
|
||||||
|
|
||||||
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -325,9 +404,36 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
|
||||||
|
allRoutes := make(map[string][]routeNode)
|
||||||
|
|
||||||
for _, mesh := range m.meshManager.GetMeshes() {
|
for _, mesh := range m.meshManager.GetMeshes() {
|
||||||
err := m.updateWgConf(mesh)
|
routes := m.getRoutes(mesh)
|
||||||
|
|
||||||
|
for destination, route := range routes {
|
||||||
|
_, ok := allRoutes[destination]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
allRoutes[destination] = route
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
|
||||||
|
allRoutes[destination] = append(allRoutes[destination], route...)
|
||||||
|
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
|
||||||
|
allRoutes[destination] = route
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||||
|
allRoutes := m.getAllRoutes()
|
||||||
|
|
||||||
|
for _, mesh := range m.meshManager.GetMeshes() {
|
||||||
|
err := m.updateWgConf(mesh, allRoutes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -366,5 +472,9 @@ func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer
|
|||||||
return &WgMeshConfigApplyer{
|
return &WgMeshConfigApplyer{
|
||||||
config: config,
|
config: config,
|
||||||
routeInstaller: route.NewRouteInstaller(),
|
routeInstaller: route.NewRouteInstaller(),
|
||||||
|
hashFunc: func(mn MeshNode) int {
|
||||||
|
pubKey, _ := mn.GetPublicKey()
|
||||||
|
return lib.HashString(pubKey.String())
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,7 +276,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.Meshes[params.MeshId].AddNode(node)
|
s.Meshes[params.MeshId].AddNode(node)
|
||||||
return s.RouteManager.UpdateRoutes()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LeaveMesh leaves the mesh network
|
// LeaveMesh leaves the mesh network
|
||||||
@ -287,10 +287,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
|||||||
return fmt.Errorf("mesh %s does not exist", meshId)
|
return fmt.Errorf("mesh %s does not exist", meshId)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
||||||
|
|
||||||
s.RouteManager.RemoveRoutes(meshId)
|
|
||||||
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -471,7 +468,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
|
|||||||
m.RouteManager = params.RouteManager
|
m.RouteManager = params.RouteManager
|
||||||
|
|
||||||
if m.RouteManager == nil {
|
if m.RouteManager == nil {
|
||||||
m.RouteManager = NewRouteManager(m)
|
m.RouteManager = NewRouteManager(m, ¶ms.Conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.idGenerator = params.IdGenerator
|
m.idGenerator = params.IdGenerator
|
||||||
|
@ -1,23 +1,25 @@
|
|||||||
package mesh
|
package mesh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouteManager interface {
|
type RouteManager interface {
|
||||||
UpdateRoutes() error
|
UpdateRoutes() error
|
||||||
RemoveRoutes(meshId string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteManagerImpl struct {
|
type RouteManagerImpl struct {
|
||||||
meshManager MeshManager
|
meshManager MeshManager
|
||||||
|
conf *conf.WgMeshConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RouteManagerImpl) UpdateRoutes() error {
|
func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||||
meshes := r.meshManager.GetMeshes()
|
meshes := r.meshManager.GetMeshes()
|
||||||
ulaBuilder := new(ip.ULABuilder)
|
routes := make(map[string][]Route)
|
||||||
|
|
||||||
for _, mesh1 := range meshes {
|
for _, mesh1 := range meshes {
|
||||||
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
||||||
@ -26,68 +28,84 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := self.GetPublicKey()
|
if _, ok := routes[mesh1.GetMeshId()]; !ok {
|
||||||
|
routes[mesh1.GetMeshId()] = make([]Route, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
routeMap, err := mesh1.GetRoutes(NodeID(self))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := mesh1.GetRoutes(pubKey.String())
|
if r.conf.AdvertiseDefaultRoute {
|
||||||
|
_, ipv6Default, _ := net.ParseCIDR("::/0")
|
||||||
|
|
||||||
if err != nil {
|
mesh1.AddRoutes(NodeID(self),
|
||||||
return err
|
&RouteStub{
|
||||||
|
Destination: ipv6Default,
|
||||||
|
HopCount: 0,
|
||||||
|
Path: make([]string, 0),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, mesh2 := range meshes {
|
for _, mesh2 := range meshes {
|
||||||
|
routeValues, ok := routes[mesh2.GetMeshId()]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
routeValues = make([]Route, 0)
|
||||||
|
}
|
||||||
|
|
||||||
if mesh1 == mesh2 {
|
if mesh1 == mesh2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId())
|
mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
|
||||||
|
|
||||||
if err != nil {
|
routeValues = append(routeValues, &RouteStub{
|
||||||
logging.Log.WriteErrorf(err.Error())
|
Destination: mesh1IpNet,
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
|
|
||||||
Destination: ipNet,
|
|
||||||
HopCount: 0,
|
HopCount: 0,
|
||||||
Path: make([]string, 0),
|
Path: []string{mesh1.GetMeshId()},
|
||||||
})...)
|
})
|
||||||
|
|
||||||
if err != nil {
|
routeValues = append(routeValues, lib.MapValues(routeMap)...)
|
||||||
return err
|
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
|
||||||
|
routeValues = lib.Filter(routeValues, func(r Route) bool {
|
||||||
|
pathNotMesh := func(s string) bool {
|
||||||
|
return s == mesh2.GetMeshId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the route does not see it's own IP
|
||||||
|
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh)
|
||||||
|
})
|
||||||
|
|
||||||
|
routes[mesh2.GetMeshId()] = routeValues
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the set different of each, working out routes to remove and to keep.
|
||||||
|
for meshId, meshRoutes := range routes {
|
||||||
|
mesh := r.meshManager.GetMesh(meshId)
|
||||||
|
self, _ := r.meshManager.GetSelf(meshId)
|
||||||
|
toRemove := make([]Route, 0)
|
||||||
|
|
||||||
|
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
|
||||||
|
|
||||||
|
for _, route := range prevRoutes {
|
||||||
|
if !lib.Contains(meshRoutes, func(r Route) bool {
|
||||||
|
return RouteEquals(r, route)
|
||||||
|
}) {
|
||||||
|
toRemove = append(toRemove, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mesh.RemoveRoutes(NodeID(self), toRemove...)
|
||||||
|
mesh.AddRoutes(NodeID(self), meshRoutes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeRoutes: removes all meshes we are no longer a part of
|
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager {
|
||||||
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
|
return &RouteManagerImpl{meshManager: m, conf: conf}
|
||||||
ulaBuilder := new(ip.ULABuilder)
|
|
||||||
meshes := r.meshManager.GetMeshes()
|
|
||||||
|
|
||||||
ipNet, err := ulaBuilder.GetIPNet(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, mesh1 := range meshes {
|
|
||||||
self, err := r.meshManager.GetSelf(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRouteManager(m MeshManager) RouteManager {
|
|
||||||
return &RouteManagerImpl{meshManager: m}
|
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveRoutes implements MeshProvider.
|
// RemoveRoutes implements MeshProvider.
|
||||||
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
|
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ package mesh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
@ -19,6 +20,12 @@ type Route interface {
|
|||||||
GetPath() []string
|
GetPath() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RouteEquals(r1, r2 Route) bool {
|
||||||
|
return r1.GetDestination().String() == r2.GetDestination().String() &&
|
||||||
|
r1.GetHopCount() == r2.GetHopCount() &&
|
||||||
|
slices.Equal(r1.GetPath(), r2.GetPath())
|
||||||
|
}
|
||||||
|
|
||||||
type RouteStub struct {
|
type RouteStub struct {
|
||||||
Destination *net.IPNet
|
Destination *net.IPNet
|
||||||
HopCount int
|
HopCount int
|
||||||
@ -71,11 +78,6 @@ func NodeEquals(node1, node2 MeshNode) bool {
|
|||||||
return key1.String() == key2.String()
|
return key1.String() == key2.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RouteEquals(route1, route2 Route) bool {
|
|
||||||
return route1.GetDestination().String() == route2.GetDestination().String() &&
|
|
||||||
route1.GetHopCount() == route2.GetHopCount()
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodeID(node MeshNode) string {
|
func NodeID(node MeshNode) string {
|
||||||
key, _ := node.GetPublicKey()
|
key, _ := node.GetPublicKey()
|
||||||
return key.String()
|
return key.String()
|
||||||
@ -116,7 +118,7 @@ type MeshProvider interface {
|
|||||||
// AddRoutes: adds routes to the given node
|
// AddRoutes: adds routes to the given node
|
||||||
AddRoutes(nodeId string, route ...Route) error
|
AddRoutes(nodeId string, route ...Route) error
|
||||||
// DeleteRoutes: deletes the routes from the node
|
// DeleteRoutes: deletes the routes from the node
|
||||||
RemoveRoutes(nodeId string, route ...string) error
|
RemoveRoutes(nodeId string, route ...Route) error
|
||||||
// GetSyncer: returns the automerge syncer for sync
|
// GetSyncer: returns the automerge syncer for sync
|
||||||
GetSyncer() MeshSyncer
|
GetSyncer() MeshSyncer
|
||||||
// GetNode get a particular not within the mesh
|
// GetNode get a particular not within the mesh
|
||||||
@ -173,7 +175,7 @@ type MeshProviderFactory interface {
|
|||||||
// MeshNodeFactoryParams are the parameters required to construct
|
// MeshNodeFactoryParams are the parameters required to construct
|
||||||
// a mesh node
|
// a mesh node
|
||||||
type MeshNodeFactoryParams struct {
|
type MeshNodeFactoryParams struct {
|
||||||
PublicKey *wgtypes.Key
|
PublicKey *wgtypes.Key
|
||||||
NodeIP net.IP
|
NodeIP net.IP
|
||||||
WgPort int
|
WgPort int
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
@ -30,15 +30,12 @@ type SyncerImpl struct {
|
|||||||
|
|
||||||
// Sync: Sync random nodes
|
// Sync: Sync random nodes
|
||||||
func (s *SyncerImpl) Sync(meshId string) error {
|
func (s *SyncerImpl) Sync(meshId string) error {
|
||||||
self, err := s.manager.GetSelf(meshId)
|
// Self can be nil if the node is removed
|
||||||
|
self, _ := s.manager.GetSelf(meshId)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.manager.GetMesh(meshId).Prune()
|
s.manager.GetMesh(meshId).Prune()
|
||||||
|
|
||||||
if self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
||||||
logging.Log.WriteInfof("No changes for %s", meshId)
|
logging.Log.WriteInfof("No changes for %s", meshId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -52,10 +49,16 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
|
|
||||||
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
||||||
|
|
||||||
|
if self != nil {
|
||||||
|
nodeNames = lib.Filter(nodeNames, func(s string) bool {
|
||||||
|
return s != mesh.NodeID(self)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
var gossipNodes []string
|
var gossipNodes []string
|
||||||
|
|
||||||
// Clients always pings its peer for configuration
|
// Clients always pings its peer for configuration
|
||||||
if self.GetType() == conf.CLIENT_ROLE {
|
if self != nil && self.GetType() == conf.CLIENT_ROLE {
|
||||||
keyFunc := lib.HashString
|
keyFunc := lib.HashString
|
||||||
bucketFunc := lib.HashString
|
bucketFunc := lib.HashString
|
||||||
|
|
||||||
@ -108,7 +111,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
|||||||
s.lastSync = uint64(time.Now().Unix())
|
s.lastSync = uint64(time.Now().Unix())
|
||||||
|
|
||||||
logging.Log.WriteInfof("UPDATING WG CONF")
|
logging.Log.WriteInfof("UPDATING WG CONF")
|
||||||
err = s.manager.ApplyConfig()
|
err := s.manager.ApplyConfig()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||||
@ -123,7 +126,7 @@ func (s *SyncerImpl) SyncMeshes() error {
|
|||||||
err := s.Sync(meshId)
|
err := s.Sync(meshId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@ import (
|
|||||||
// Run implements SyncScheduler.
|
// Run implements SyncScheduler.
|
||||||
func syncFunction(syncer Syncer) lib.TimerFunc {
|
func syncFunction(syncer Syncer) lib.TimerFunc {
|
||||||
return func() error {
|
return func() error {
|
||||||
return syncer.SyncMeshes()
|
syncer.SyncMeshes()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user