diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index c77debc..d4c6143 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -1,4 +1,4 @@ -package crdt +package automerge import ( "errors" @@ -6,7 +6,6 @@ import ( "net" "slices" "strings" - "sync" "time" "github.com/automerge/automerge-go" @@ -20,7 +19,6 @@ import ( // CrdtMeshManager manages nodes in the crdt mesh type CrdtMeshManager struct { - lock sync.RWMutex MeshId string IfName string Client *wgctrl.Client @@ -42,13 +40,10 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { crdt.Services = make(map[string]string) crdt.Timestamp = time.Now().Unix() - c.lock.Lock() c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) - c.lock.Unlock() } func (c *CrdtMeshManager) isPeer(nodeId string) bool { - c.lock.RLock() node, err := c.doc.Path("nodes").Map().Get(nodeId) if err != nil || node.Kind() != automerge.KindMap { @@ -56,7 +51,6 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool { } nodeType, err := node.Map().Get("type") - c.lock.RUnlock() if err != nil || nodeType.Kind() != automerge.KindStr { return false @@ -68,7 +62,6 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool { // isAlive: checks that the node's configuration has been updated // since the rquired keep alive time func (c *CrdtMeshManager) isAlive(nodeId string) bool { - c.lock.RLock() node, err := c.doc.Path("nodes").Map().Get(nodeId) if err != nil || node.Kind() != automerge.KindMap { @@ -76,7 +69,6 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool { } timestamp, err := node.Map().Get("timestamp") - c.lock.RUnlock() if err != nil || timestamp.Kind() != automerge.KindInt64 { return false @@ -87,9 +79,7 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool { } func (c *CrdtMeshManager) GetPeers() []string { - c.lock.RLock() keys, _ := c.doc.Path("nodes").Map().Keys() - c.lock.RUnlock() keys = lib.Filter(keys, func(publicKey string) bool { return c.isPeer(publicKey) && c.isAlive(publicKey) @@ -108,9 +98,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { if c.cache == nil || len(changes) > 0 { c.lastCacheHash = c.LastHash - c.lock.RLock() cache, err := automerge.As[*MeshCrdt](c.doc.Root()) - c.lock.RUnlock() if err != nil { return nil, err @@ -170,7 +158,6 @@ func (m *CrdtMeshManager) NodeExists(key string) bool { } func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { - m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(endpoint) if node.Kind() != automerge.KindMap { @@ -182,7 +169,6 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { } meshNode, err := automerge.As[*MeshNodeCrdt](node) - m.lock.RUnlock() if err != nil { return nil, err @@ -228,9 +214,7 @@ func (m *CrdtMeshManager) SaveChanges() { } func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { - m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) - m.lock.RUnlock() if err != nil { return err @@ -240,9 +224,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { return errors.New("node is not a map") } - m.lock.Lock() err = node.Map().Set("timestamp", time.Now().Unix()) - m.lock.Unlock() if err == nil { logging.Log.WriteInfof("Timestamp Updated for %s", nodeId) @@ -252,9 +234,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { } func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error { - m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) - m.lock.RUnlock() if err != nil { return err @@ -264,9 +244,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro return fmt.Errorf("%s does not exist", nodeId) } - m.lock.Lock() err = node.Map().Set("description", description) - m.lock.Unlock() if err == nil { logging.Log.WriteInfof("Description Updated for %s", nodeId) @@ -276,9 +254,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro } func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { - m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) - m.lock.RUnlock() if err != nil { return err @@ -288,9 +264,7 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { return fmt.Errorf("%s does not exist", nodeId) } - m.lock.Lock() err = node.Map().Set("alias", alias) - m.lock.Unlock() if err == nil { logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias) @@ -300,17 +274,13 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { } func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { - m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) - m.lock.RUnlock() if err != nil || node.Kind() != automerge.KindMap { return fmt.Errorf("AddService: node %s does not exist", nodeId) } - m.lock.RLock() service, err := node.Map().Get("services") - m.lock.RUnlock() if err != nil { return err @@ -320,14 +290,11 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { return fmt.Errorf("AddService: services property does not exist in node") } - m.lock.Lock() err = service.Map().Set(key, value) - m.lock.Unlock() return err } func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { - m.lock.RLock() node, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil || node.Kind() != automerge.KindMap { @@ -343,11 +310,8 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { if service.Kind() != automerge.KindMap { return fmt.Errorf("services property does not exist") } - m.lock.RUnlock() - m.lock.Lock() err = service.Map().Delete(key) - m.lock.Unlock() if err != nil { return fmt.Errorf("service %s does not exist", key) @@ -358,7 +322,6 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { // AddRoutes: adds routes to the specific nodeId func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { - m.lock.RLock() nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId) @@ -371,7 +334,6 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { } routeMap, err := nodeVal.Map().Get("routes") - m.lock.RUnlock() if err != nil { return err @@ -400,12 +362,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { slices.Equal(route.GetPath(), pathStr) } - m.lock.Lock() err = routeMap.Map().Set(route.GetDestination().String(), Route{ Destination: route.GetDestination().String(), Path: route.GetPath(), }) - m.lock.Unlock() if err != nil { return err @@ -415,7 +375,6 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { } func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { - m.lock.RLock() nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil { @@ -437,7 +396,6 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { } routes, err := automerge.As[map[string]Route](routeMap) - m.lock.RUnlock() return lib.MapValues(routes), err } @@ -486,15 +444,12 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e } func (m *CrdtMeshManager) RemoveNode(nodeId string) error { - m.lock.Lock() err := m.doc.Path("nodes").Map().Delete(nodeId) - m.lock.Unlock() return err } // DeleteRoutes deletes the specified routes func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { - m.lock.RLock() nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) if err != nil { @@ -506,17 +461,14 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { } routeMap, err := nodeVal.Map().Get("routes") - m.lock.RUnlock() if err != nil { return err } - m.lock.Lock() for _, route := range routes { err = routeMap.Map().Delete(route) } - m.lock.Unlock() return err } @@ -526,7 +478,6 @@ func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { } func (m *CrdtMeshManager) Prune(pruneTime int) error { - m.lock.RLock() nodes, err := m.doc.Path("nodes").Get() if err != nil { @@ -538,7 +489,6 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error { } values, err := nodes.Map().Values() - m.lock.RUnlock() if err != nil { return err @@ -553,9 +503,7 @@ func (m *CrdtMeshManager) Prune(pruneTime int) error { nodeMap := node.Map() - m.lock.RLock() timeStamp, err := nodeMap.Get("timestamp") - m.lock.RUnlock() if err != nil { return err @@ -601,7 +549,6 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet { _, ipnet, err := net.ParseCIDR(m.WgHost) if err != nil { - logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error()) return nil } @@ -629,7 +576,6 @@ func (m *MeshNodeCrdt) GetIdentifier() string { ipv6 := m.WgHost[:len(m.WgHost)-4] constituents := strings.Split(ipv6, ":") - logging.Log.WriteInfof(ipv6) constituents = constituents[4:] return strings.Join(constituents, ":") } diff --git a/pkg/automerge/automerge_sync.go b/pkg/automerge/automerge_sync.go index 86ca53c..f1510e6 100644 --- a/pkg/automerge/automerge_sync.go +++ b/pkg/automerge/automerge_sync.go @@ -1,4 +1,4 @@ -package crdt +package automerge import ( "github.com/automerge/automerge-go" @@ -32,6 +32,7 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error { func (a *AutomergeSync) Complete() { logging.Log.WriteInfof("Sync Completed") + a.manager.SaveChanges() } func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { diff --git a/pkg/automerge/automerge_test.go b/pkg/automerge/automerge_test.go index e882930..396b506 100644 --- a/pkg/automerge/automerge_test.go +++ b/pkg/automerge/automerge_test.go @@ -1,4 +1,4 @@ -package crdt +package automerge import ( "slices" diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index ebccf50..69f9306 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -1,4 +1,4 @@ -package crdt +package automerge import ( "fmt" diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index 150ff8d..fa2b20c 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -1,4 +1,4 @@ -package crdt +package automerge // Route: Represents a CRDT of the given route type Route struct { diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go new file mode 100644 index 0000000..599224a --- /dev/null +++ b/pkg/crdt/datastore.go @@ -0,0 +1,442 @@ +package crdt + +import ( + "bytes" + "encoding/gob" + "fmt" + "net" + "strings" + "time" + + "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" + logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/mesh" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type Route struct { + Destination string + Path []string +} + +// GetDestination implements mesh.Route. +func (r *Route) GetDestination() *net.IPNet { + _, ipnet, _ := net.ParseCIDR(r.Destination) + return ipnet +} + +// GetHopCount implements mesh.Route. +func (r *Route) GetHopCount() int { + return len(r.Path) +} + +// GetPath implements mesh.Route. +func (r *Route) GetPath() []string { + return r.Path +} + +type MeshNode struct { + HostEndpoint string + WgEndpoint string + PublicKey string + WgHost string + Timestamp int64 + Routes map[string]Route + Alias string + Description string + Services map[string]string + Type string +} + +// GetHostEndpoint: gets the gRPC endpoint of the node +func (n *MeshNode) GetHostEndpoint() string { + return n.HostEndpoint +} + +// GetPublicKey: gets the public key of the node +func (n *MeshNode) GetPublicKey() (wgtypes.Key, error) { + return wgtypes.ParseKey(n.PublicKey) +} + +// GetWgEndpoint(): get IP and port of the wireguard endpoint +func (n *MeshNode) GetWgEndpoint() string { + return n.WgEndpoint +} + +// GetWgHost: get the IP address of the WireGuard node +func (n *MeshNode) GetWgHost() *net.IPNet { + _, ipnet, _ := net.ParseCIDR(n.WgHost) + return ipnet +} + +// GetTimestamp: get the UNIX time stamp of the ndoe +func (n *MeshNode) GetTimeStamp() int64 { + return n.Timestamp +} + +// GetRoutes: returns the routes that the nodes provides +func (n *MeshNode) GetRoutes() []mesh.Route { + routes := make([]mesh.Route, len(n.Routes)) + + for index, route := range lib.MapValues(n.Routes) { + routes[index] = &Route{ + Destination: route.Destination, + Path: route.Path, + } + } + + return routes +} + +// GetIdentifier: returns the identifier of the node +func (m *MeshNode) GetIdentifier() string { + ipv6 := m.WgHost[:len(m.WgHost)-4] + + constituents := strings.Split(ipv6, ":") + constituents = constituents[4:] + return strings.Join(constituents, ":") +} + +// GetDescription: returns the description for this node +func (n *MeshNode) GetDescription() string { + return n.Description +} + +// GetAlias: associates the node with an alias. Potentially used +// for DNS and so forth. +func (n *MeshNode) GetAlias() string { + return n.Alias +} + +// GetServices: returns a list of services offered by the node +func (n *MeshNode) GetServices() map[string]string { + return n.Services +} + +func (n *MeshNode) GetType() conf.NodeType { + return conf.NodeType(n.Type) +} + +type MeshSnapshot struct { + Nodes map[string]MeshNode +} + +// GetNodes() returns the nodes in the mesh +func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode { + newMap := make(map[string]mesh.MeshNode) + + for key, value := range m.Nodes { + newMap[key] = &MeshNode{ + HostEndpoint: value.HostEndpoint, + PublicKey: value.PublicKey, + WgHost: value.WgHost, + WgEndpoint: value.WgEndpoint, + Timestamp: value.Timestamp, + Routes: value.Routes, + Alias: value.Alias, + Description: value.Description, + Services: value.Services, + Type: value.Type, + } + } + + return newMap +} + +type TwoPhaseStoreMeshManager struct { + MeshId string + IfName string + Client *wgctrl.Client + LastClock uint64 + conf *conf.WgMeshConfiguration + store *TwoPhaseMap[string, MeshNode] +} + +// AddNode() adds a node to the mesh +func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) { + crdt, ok := node.(*MeshNode) + + if !ok { + panic("node must be of type mesh node") + } + + crdt.Routes = make(map[string]Route) + crdt.Services = make(map[string]string) + crdt.Timestamp = time.Now().Unix() + + m.store.Put(crdt.PublicKey, *crdt) +} + +// GetMesh() returns a snapshot of the mesh provided by the mesh provider. +func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) { + return &MeshSnapshot{ + Nodes: m.store.AsMap(), + }, nil +} + +// GetMeshId() returns the ID of the mesh network +func (m *TwoPhaseStoreMeshManager) GetMeshId() string { + return m.MeshId +} + +// Save() saves the mesh network +func (m *TwoPhaseStoreMeshManager) Save() []byte { + snapshot := m.store.Snapshot() + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + + err := enc.Encode(*snapshot) + + if err != nil { + logging.Log.WriteInfof(err.Error()) + } + + return buf.Bytes() +} + +// Load() loads a mesh network +func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error { + buf := bytes.NewBuffer(bs) + + dec := gob.NewDecoder(buf) + + var snapshot TwoPhaseMapSnapshot[string, MeshNode] + err := dec.Decode(&snapshot) + m.store.Merge(snapshot) + return err +} + +// GetDevice() get the device corresponding with the mesh +func (m *TwoPhaseStoreMeshManager) GetDevice() (*wgtypes.Device, error) { + dev, err := m.Client.Device(m.IfName) + + if err != nil { + return nil, err + } + + return dev, nil +} + +// HasChanges returns true if we have changes since last time we synced +func (m *TwoPhaseStoreMeshManager) HasChanges() bool { + clockValue := m.store.GetClock() + return clockValue != m.LastClock +} + +// Record that we have changes and save the corresponding changes +func (m *TwoPhaseStoreMeshManager) SaveChanges() { + clockValue := m.store.GetClock() + m.LastClock = clockValue +} + +// UpdateTimeStamp: update the timestamp of the given node +func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + node := m.store.Get(nodeId) + node.Timestamp = time.Now().Unix() + m.store.Put(nodeId, node) + return nil +} + +// AddRoutes: adds routes to the given node +func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + if len(routes) == 0 { + return nil + } + + node := m.store.Get(nodeId) + + for _, route := range routes { + node.Routes[route.GetDestination().String()] = Route{ + Destination: route.GetDestination().String(), + Path: route.GetPath(), + } + } + + m.store.Put(nodeId, node) + return nil +} + +// DeleteRoutes: deletes the routes from the node +func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + if len(routes) == 0 { + return nil + } + + node := m.store.Get(nodeId) + + for _, route := range routes { + delete(node.Routes, route) + } + + return nil +} + +// GetSyncer: returns the automerge syncer for sync +func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer { + return NewTwoPhaseSyncer(m) +} + +// GetNode get a particular not within the mesh +func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) { + if !m.store.Contains(nodeId) { + return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + node := m.store.Get(nodeId) + return &node, nil +} + +// NodeExists: returns true if a particular node exists false otherwise +func (m *TwoPhaseStoreMeshManager) NodeExists(nodeId string) bool { + return m.store.Contains(nodeId) +} + +// SetDescription: sets the description of this automerge data type +func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + node := m.store.Get(nodeId) + node.Description = description + + m.store.Put(nodeId, node) + return nil +} + +// SetAlias: set the alias of the nodeId +func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + node := m.store.Get(nodeId) + node.Description = alias + + m.store.Put(nodeId, node) + return nil +} + +// AddService: adds the service to the given node +func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + node := m.store.Get(nodeId) + node.Services[key] = value + m.store.Put(nodeId, node) + return nil +} + +// RemoveService: removes the service form the node. throws an error if the service does not exist +func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + node := m.store.Get(nodeId) + delete(node.Services, key) + m.store.Put(nodeId, node) + return nil +} + +// Prune: prunes all nodes that have not updated their timestamp in +// pruneAmount seconds +func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error { + return nil +} + +// GetPeers: get a list of contactable peers +func (m *TwoPhaseStoreMeshManager) GetPeers() []string { + nodes := lib.MapValues(m.store.AsMap()) + nodes = lib.Filter(nodes, func(mn MeshNode) bool { + if mn.Type != string(conf.PEER_ROLE) { + return false + } + + return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime) + }) + + return lib.Map(nodes, func(mn MeshNode) string { + return mn.PublicKey + }) +} + +func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) { + if !m.store.Contains(targetNode) { + return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode) + } + + node := m.store.Get(targetNode) + return node.Routes, nil +} + +// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen +func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) { + node, err := m.GetNode(targetNode) + + if err != nil { + return nil, err + } + + routes := make(map[string]mesh.Route) + + // Add routes that the node directly has + for _, route := range node.GetRoutes() { + routes[route.GetDestination().String()] = route + } + + // Work out the other routes in the mesh + for _, node := range m.GetPeers() { + nodeRoutes, err := m.getRoutes(node) + + if err != nil { + return nil, err + } + + for _, route := range nodeRoutes { + otherRoute, ok := routes[route.GetDestination().String()] + + hopCount := route.GetHopCount() + + if node != targetNode { + hopCount += 1 + } + + if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() { + routes[route.GetDestination().String()] = &Route{ + Destination: route.GetDestination().String(), + Path: append(route.GetPath(), m.GetMeshId()), + } + } + } + } + + return routes, nil +} + +// RemoveNode(): remove the node from the mesh +func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error { + if !m.store.Contains(nodeId) { + return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) + } + + m.store.Remove(nodeId) + return nil +} diff --git a/pkg/crdt/factory.go b/pkg/crdt/factory.go new file mode 100644 index 0000000..5e2ddc6 --- /dev/null +++ b/pkg/crdt/factory.go @@ -0,0 +1,73 @@ +package crdt + +import ( + "fmt" + + "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" + "github.com/tim-beatham/wgmesh/pkg/mesh" +) + +type TwoPhaseMapFactory struct{} + +func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { + return &TwoPhaseStoreMeshManager{ + MeshId: params.MeshId, + IfName: params.DevName, + Client: params.Client, + conf: params.Conf, + store: NewTwoPhaseMap[string, MeshNode](params.NodeID), + }, nil +} + +type MeshNodeFactory struct { + Config conf.WgMeshConfiguration +} + +func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { + hostName := f.getAddress(params) + + grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) + + if f.Config.Role == conf.CLIENT_ROLE { + grpcEndpoint = "-" + } + + return &MeshNode{ + HostEndpoint: grpcEndpoint, + PublicKey: params.PublicKey.String(), + WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), + WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), + Routes: make(map[string]Route), + Description: "", + Alias: "", + Type: string(f.Config.Role), + } +} + +// getAddress returns the routable address of the machine. +func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string { + var hostName string = "" + + if params.Endpoint != "" { + hostName = params.Endpoint + } else if len(f.Config.Endpoint) != 0 { + hostName = f.Config.Endpoint + } else { + ipFunc := lib.GetPublicIP + + if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY { + ipFunc = lib.GetOutboundIP + } + + ip, err := ipFunc() + + if err != nil { + return "" + } + + hostName = ip.String() + } + + return hostName +} diff --git a/pkg/crdt/g_map.go b/pkg/crdt/g_map.go new file mode 100644 index 0000000..81b916a --- /dev/null +++ b/pkg/crdt/g_map.go @@ -0,0 +1,121 @@ +// crdt is a golang implementation of a crdt +package crdt + +import ( + "sync" +) + +type Bucket[D any] struct { + Vector uint64 + Contents D +} + +// GMap is a set that can only grow in size +type GMap[K comparable, D any] struct { + lock sync.RWMutex + contents map[K]Bucket[D] + getClock func() uint64 +} + +func (g *GMap[K, D]) Put(key K, value D) { + g.lock.Lock() + + clock := g.getClock() + 1 + + g.contents[key] = Bucket[D]{ + Vector: clock, + Contents: value, + } + + g.lock.Unlock() +} + +func (g *GMap[K, D]) Contains(key K) bool { + g.lock.RLock() + + _, ok := g.contents[key] + + g.lock.RUnlock() + + return ok +} + +func (g *GMap[K, D]) put(key K, b Bucket[D]) { + g.lock.Lock() + + if g.contents[key].Vector < b.Vector { + g.contents[key] = b + } + + g.lock.Unlock() +} + +func (g *GMap[K, D]) get(key K) Bucket[D] { + g.lock.RLock() + bucket := g.contents[key] + g.lock.RUnlock() + + return bucket +} + +func (g *GMap[K, D]) Get(key K) D { + return g.get(key).Contents +} + +func (g *GMap[K, D]) Keys() []K { + g.lock.RLock() + + contents := make([]K, len(g.contents)) + index := 0 + + for key := range g.contents { + contents[index] = key + index++ + } + + g.lock.RUnlock() + return contents +} + +func (g *GMap[K, D]) Save() map[K]Bucket[D] { + buckets := make(map[K]Bucket[D]) + g.lock.RLock() + + for key, value := range g.contents { + buckets[key] = value + } + + g.lock.RUnlock() + return buckets +} + +func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] { + buckets := make(map[K]Bucket[D]) + g.lock.RLock() + + for _, key := range keys { + buckets[key] = g.contents[key] + } + + g.lock.RUnlock() + return buckets +} + +func (g *GMap[K, D]) GetClock() map[K]uint64 { + clock := make(map[K]uint64) + g.lock.RLock() + + for key, bucket := range g.contents { + clock[key] = bucket.Vector + } + + g.lock.RUnlock() + return clock +} + +func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] { + return &GMap[K, D]{ + contents: make(map[K]Bucket[D]), + getClock: getClock, + } +} diff --git a/pkg/crdt/two_phase_map.go b/pkg/crdt/two_phase_map.go new file mode 100644 index 0000000..25ec162 --- /dev/null +++ b/pkg/crdt/two_phase_map.go @@ -0,0 +1,208 @@ +package crdt + +import ( + "sync" + + "github.com/tim-beatham/wgmesh/pkg/lib" +) + +type TwoPhaseMap[K comparable, D any] struct { + addMap *GMap[K, D] + removeMap *GMap[K, bool] + vectors map[K]uint64 + processId K + lock sync.RWMutex +} + +type TwoPhaseMapSnapshot[K comparable, D any] struct { + Add map[K]Bucket[D] + Remove map[K]Bucket[bool] +} + +// Contains checks whether the value exists in the map +func (m *TwoPhaseMap[K, D]) Contains(key K) bool { + if !m.addMap.Contains(key) { + return false + } + + addValue := m.addMap.get(key) + + if !m.removeMap.Contains(key) { + return true + } + + removeValue := m.removeMap.get(key) + + return addValue.Vector >= removeValue.Vector +} + +func (m *TwoPhaseMap[K, D]) Get(key K) D { + var result D + + if !m.Contains(key) { + return result + } + + return m.addMap.Get(key) +} + +// Put places the key K in the map +func (m *TwoPhaseMap[K, D]) Put(key K, data D) { + msgSequence := m.incrementClock() + + m.lock.Lock() + + if _, ok := m.vectors[key]; !ok { + m.vectors[key] = msgSequence + } + + m.lock.Unlock() + m.addMap.Put(key, data) +} + +// Remove removes the value from the map +func (m *TwoPhaseMap[K, D]) Remove(key K) { + m.removeMap.Put(key, true) +} + +func (m *TwoPhaseMap[K, D]) Keys() []K { + keys := make([]K, 0) + + addKeys := m.addMap.Keys() + + for _, key := range addKeys { + if !m.Contains(key) { + continue + } + + keys = append(keys, key) + } + + return keys +} + +func (m *TwoPhaseMap[K, D]) AsMap() map[K]D { + theMap := make(map[K]D) + + keys := m.Keys() + + for _, key := range keys { + theMap[key] = m.Get(key) + } + + return theMap +} + +func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] { + return &TwoPhaseMapSnapshot[K, D]{ + Add: m.addMap.Save(), + Remove: m.removeMap.Save(), + } +} + +func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] { + addKeys := lib.MapKeys(state.AddContents) + removeKeys := lib.MapKeys(state.RemoveContents) + + return &TwoPhaseMapSnapshot[K, D]{ + Add: m.addMap.SaveWithKeys(addKeys), + Remove: m.removeMap.SaveWithKeys(removeKeys), + } +} + +type TwoPhaseMapState[K comparable] struct { + AddContents map[K]uint64 + RemoveContents map[K]uint64 +} + +func (m *TwoPhaseMap[K, D]) incrementClock() uint64 { + maxClock := uint64(0) + m.lock.Lock() + + for _, value := range m.vectors { + maxClock = max(maxClock, value) + } + + m.vectors[m.processId] = maxClock + 1 + m.lock.Unlock() + return maxClock +} + +func (m *TwoPhaseMap[K, D]) GetClock() uint64 { + maxClock := uint64(0) + m.lock.RLock() + + for _, value := range m.vectors { + maxClock = max(maxClock, value) + } + + m.lock.RUnlock() + return maxClock +} + +// GetState: get the current vector clock of the add and remove +// map +func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] { + addContents := m.addMap.GetClock() + removeContents := m.removeMap.GetClock() + + return &TwoPhaseMapState[K]{ + AddContents: addContents, + RemoveContents: removeContents, + } +} + +func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] { + mapState := &TwoPhaseMapState[K]{ + AddContents: make(map[K]uint64), + RemoveContents: make(map[K]uint64), + } + + for key, value := range state.AddContents { + otherValue, ok := m.AddContents[key] + + if !ok || otherValue < value { + mapState.AddContents[key] = value + } + } + + for key, value := range state.AddContents { + otherValue, ok := m.RemoveContents[key] + + if !ok || otherValue < value { + mapState.RemoveContents[key] = value + } + } + + return mapState +} + +func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) { + m.lock.Lock() + + for key, value := range snapshot.Add { + m.addMap.put(key, value) + m.vectors[key] = max(value.Vector, m.vectors[key]) + } + + for key, value := range snapshot.Remove { + m.removeMap.put(key, value) + m.vectors[key] = max(value.Vector, m.vectors[key]) + } + + m.lock.Unlock() +} + +// NewTwoPhaseMap: create a new two phase map. Consists of two maps +// a grow map and a remove map. If both timestamps equal then favour keeping +// it in the map +func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] { + m := TwoPhaseMap[K, D]{ + vectors: make(map[K]uint64), + processId: processId, + } + + m.addMap = NewGMap[K, D](m.incrementClock) + m.removeMap = NewGMap[K, bool](m.incrementClock) + return &m +} diff --git a/pkg/crdt/two_phase_map_syncer.go b/pkg/crdt/two_phase_map_syncer.go new file mode 100644 index 0000000..77645bd --- /dev/null +++ b/pkg/crdt/two_phase_map_syncer.go @@ -0,0 +1,145 @@ +package crdt + +import ( + "bytes" + "encoding/gob" + + logging "github.com/tim-beatham/wgmesh/pkg/log" +) + +type SyncState int + +const ( + PREPARE SyncState = iota + PRESENT + EXCHANGE + MERGE + FINISHED +) + +// TwoPhaseSyncer is a type to sync a TwoPhase data store +type TwoPhaseSyncer struct { + manager *TwoPhaseStoreMeshManager + generateMessageFSM SyncFSM + state SyncState + mapState *TwoPhaseMapState[string] + peerMsg []byte +} + +type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool) + +func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) { + var buffer bytes.Buffer + enc := gob.NewEncoder(&buffer) + + err := enc.Encode(*syncer.mapState) + + if err != nil { + logging.Log.WriteInfof(err.Error()) + } + + syncer.IncrementState() + return buffer.Bytes(), true +} + +func present(syncer *TwoPhaseSyncer) ([]byte, bool) { + if syncer.peerMsg == nil { + panic("peer msg is nil") + } + + var recvBuffer = bytes.NewBuffer(syncer.peerMsg) + dec := gob.NewDecoder(recvBuffer) + + var mapState TwoPhaseMapState[string] + err := dec.Decode(&mapState) + + if err != nil { + logging.Log.WriteInfof(err.Error()) + } + + difference := syncer.mapState.Difference(&mapState) + + var sendBuffer bytes.Buffer + enc := gob.NewEncoder(&sendBuffer) + enc.Encode(*difference) + + syncer.IncrementState() + return sendBuffer.Bytes(), true +} + +func exchange(syncer *TwoPhaseSyncer) ([]byte, bool) { + if syncer.peerMsg == nil { + panic("peer msg is nil") + } + + var recvBuffer = bytes.NewBuffer(syncer.peerMsg) + dec := gob.NewDecoder(recvBuffer) + + var mapState TwoPhaseMapState[string] + dec.Decode(&mapState) + + snapshot := syncer.manager.store.SnapShotFromState(&mapState) + + var sendBuffer bytes.Buffer + enc := gob.NewEncoder(&sendBuffer) + enc.Encode(*snapshot) + + syncer.IncrementState() + return sendBuffer.Bytes(), true +} + +func merge(syncer *TwoPhaseSyncer) ([]byte, bool) { + if syncer.peerMsg == nil { + panic("peer msg is nil") + } + + var recvBuffer = bytes.NewBuffer(syncer.peerMsg) + dec := gob.NewDecoder(recvBuffer) + + var snapshot TwoPhaseMapSnapshot[string, MeshNode] + dec.Decode(&snapshot) + + syncer.manager.store.Merge(snapshot) + + return nil, false +} + +func (t *TwoPhaseSyncer) IncrementState() { + t.state = min(t.state+1, FINISHED) +} + +func (t *TwoPhaseSyncer) GenerateMessage() ([]byte, bool) { + fsmFunc, ok := t.generateMessageFSM[t.state] + + if !ok { + panic("state not handled") + } + + return fsmFunc(t) +} + +func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error { + t.peerMsg = msg + return nil +} + +func (t *TwoPhaseSyncer) Complete() { + logging.Log.WriteInfof("SYNC COMPLETED") + t.manager.SaveChanges() +} + +func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer { + var generateMessageFsm SyncFSM = SyncFSM{ + PREPARE: prepare, + PRESENT: present, + EXCHANGE: exchange, + MERGE: merge, + } + + return &TwoPhaseSyncer{ + manager: manager, + state: PREPARE, + mapState: manager.store.GenerateMessage(), + generateMessageFSM: generateMessageFsm, + } +} diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 92de0e4..9e54b44 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -1,9 +1,9 @@ package ctrlserver import ( - crdt "github.com/tim-beatham/wgmesh/pkg/automerge" "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conn" + "github.com/tim-beatham/wgmesh/pkg/crdt" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" @@ -28,8 +28,8 @@ type NewCtrlServerParams struct { // operation failed func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer := new(MeshCtrlServer) - meshFactory := crdt.CrdtProviderFactory{} - nodeFactory := crdt.MeshNodeFactory{ + meshFactory := &crdt.TwoPhaseMapFactory{} + nodeFactory := &crdt.MeshNodeFactory{ Config: *params.Conf, } idGenerator := &lib.IDNameGenerator{} @@ -41,8 +41,8 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { meshManagerParams := &mesh.NewMeshManagerParams{ Conf: *params.Conf, Client: params.Client, - MeshProvider: &meshFactory, - NodeFactory: &nodeFactory, + MeshProvider: meshFactory, + NodeFactory: nodeFactory, IdGenerator: idGenerator, IPAllocator: ipAllocator, InterfaceManipulator: interfaceManipulator, diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 145848f..ff51311 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -226,8 +226,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } cfg := wgtypes.Config{ - Peers: peerConfigs, - ReplacePeers: true, + Peers: peerConfigs, } dev, err := mesh.GetDevice() diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index e099bb0..9573d56 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -146,6 +146,7 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { Conf: m.conf, Client: m.Client, MeshId: meshId, + NodeID: m.HostParameters.GetPublicKey(), }) if err != nil { @@ -183,6 +184,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { Conf: m.conf, Client: m.Client, MeshId: params.MeshId, + NodeID: m.HostParameters.GetPublicKey(), }) if err != nil { @@ -214,11 +216,6 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider { // GetPublicKey: Gets the public key of the WireGuard mesh func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key { - if s.conf.StubWg { - zeroedKey := make([]byte, wgtypes.KeyLen) - return (*wgtypes.Key)(zeroedKey) - } - key := s.HostParameters.PrivateKey.PublicKey() return &key } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 9f461fc..251805a 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -159,6 +159,7 @@ type MeshProviderFactoryParams struct { Port int Conf *conf.WgMeshConfiguration Client *wgctrl.Client + NodeID string } // MeshProviderFactory creates an instance of a mesh provider diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index a878694..be7d836 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -45,6 +45,8 @@ func (s *SyncerImpl) Sync(meshId string) error { publicKey := s.manager.GetPublicKey() + logging.Log.WriteInfof(publicKey.String()) + nodeNames := s.manager.GetMesh(meshId).GetPeers() neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) @@ -87,11 +89,6 @@ func (s *SyncerImpl) Sync(meshId string) error { logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) - - // Check if any changes have occurred and trigger callbacks - // if changes have occurred. - // return s.manager.GetMonitor().Trigger() - s.manager.GetMesh(meshId).SaveChanges() return nil } diff --git a/pkg/timers/timers.go b/pkg/timers/timers.go index 84e1e7c..bbc8430 100644 --- a/pkg/timers/timers.go +++ b/pkg/timers/timers.go @@ -3,10 +3,12 @@ package timer import ( "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/lib" + logging "github.com/tim-beatham/wgmesh/pkg/log" ) func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { timerFunc := func() error { + logging.Log.WriteInfof("Updated Timestamp") return ctrlServer.MeshManager.UpdateTimeStamp() }