diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index dbd58e2..f396e3b 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -216,6 +216,23 @@ func deleteService(client *ipcRpc.Client, service string) { fmt.Println(reply) } +func getNode(client *ipcRpc.Client, nodeId, meshId string) { + var reply string + args := &ipc.GetNodeArgs{ + NodeId: nodeId, + MeshId: meshId, + } + + err := client.Call("IpcHandler.GetNode", &args, &reply) + + if err != nil { + fmt.Println(err.Error()) + return + } + + fmt.Println(reply) +} + func main() { parser := argparse.NewParser("wg-mesh", "wg-mesh Manipulate WireGuard meshes") @@ -232,6 +249,7 @@ func main() { putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node") setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") + getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh") var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) @@ -261,6 +279,9 @@ func main() { var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true}) + var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true}) + var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true}) + err := parser.Parse(os.Args) if err != nil { @@ -329,4 +350,8 @@ func main() { if deleteServiceCmd.Happened() { deleteService(client, *deleteServiceKey) } + + if getNodeCmd.Happened() { + getNode(client, *getNodeNodeId, *getNodeMeshId) + } } diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index 6261545..e97f378 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -1,7 +1,8 @@ package main import ( - "log" + "net/http" + _ "net/http/pprof" "os" "os/signal" @@ -35,6 +36,12 @@ func main() { return } + if conf.Profile { + go func() { + http.ListenAndServe("localhost:6060", nil) + }() + } + var robinRpc robin.WgRpc var robinIpc robin.IpcHandler var syncProvider sync.SyncServiceImpl @@ -65,7 +72,7 @@ func main() { return } - log.Println("Running IPC Handler") + logging.Log.WriteInfof("Running IPC Handler") go ipc.RunIpcHandler(&robinIpc) go syncScheduler.Run() diff --git a/go.mod b/go.mod index d1ef671..452ea64 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,12 @@ go 1.21.3 require ( github.com/akamensky/argparse v1.4.0 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 + github.com/gin-gonic/gin v1.9.1 github.com/google/uuid v1.3.0 github.com/jmespath/go-jmespath v0.4.0 + github.com/jsimonetti/rtnetlink v1.3.5 github.com/sirupsen/logrus v1.9.3 + golang.org/x/sys v0.14.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 google.golang.org/grpc v1.58.1 google.golang.org/protobuf v1.31.0 @@ -19,7 +22,6 @@ require ( github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect - github.com/gin-gonic/gin v1.9.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect @@ -27,7 +29,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/josharian/native v1.1.0 // indirect - github.com/jsimonetti/rtnetlink v1.3.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect @@ -40,12 +41,10 @@ require ( github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect - github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.13.0 // indirect golang.org/x/net v0.15.0 // indirect golang.org/x/sync v0.3.0 // indirect - golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 8406d6a..8131ba4 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -18,12 +18,14 @@ import ( // CrdtMeshManager manages nodes in the crdt mesh type CrdtMeshManager struct { - MeshId string - IfName string - Client *wgctrl.Client - doc *automerge.Doc - LastHash automerge.ChangeHash - conf *conf.WgMeshConfiguration + MeshId string + IfName string + Client *wgctrl.Client + doc *automerge.Doc + LastHash automerge.ChangeHash + conf *conf.WgMeshConfiguration + cache *MeshCrdt + lastCacheHash automerge.ChangeHash } func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { @@ -40,9 +42,31 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) } +func (c *CrdtMeshManager) GetNodeIds() []string { + keys, _ := c.doc.Path("nodes").Map().Keys() + return keys +} + // GetMesh(): Converts the document into a struct func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { - return automerge.As[*MeshCrdt](c.doc.Root()) + changes, err := c.doc.Changes(c.lastCacheHash) + + if err != nil { + return nil, err + } + + if c.cache == nil || len(changes) > 3 { + c.lastCacheHash = c.LastHash + cache, err := automerge.As[*MeshCrdt](c.doc.Root()) + + if err != nil { + return nil, err + } + + c.cache = cache + } + + return c.cache, nil } // GetMeshId returns the meshid of the mesh @@ -82,13 +106,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro manager.IfName = params.DevName manager.Client = params.Client manager.conf = ¶ms.Conf + manager.cache = nil return &manager, nil } -// GetNode: returns a mesh node crdt.Close releases resources used by a Client. -func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { +// NodeExists: returns true if the node exists. Returns false +func (m *CrdtMeshManager) NodeExists(key string) bool { + node, err := m.doc.Path("nodes").Map().Get(key) + return node.Kind() == automerge.KindMap && err != nil +} + +func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { node, err := m.doc.Path("nodes").Map().Get(endpoint) + if node.Kind() != automerge.KindMap { + return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type") + } + if err != nil { return nil, err } diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 26e1dfb..ad68721 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -49,6 +49,10 @@ type WgMeshConfiguration struct { Timeout int `yaml:"timeout"` // PruneTime number of seconds before we consider the 'node' as dead PruneTime int `yaml:"pruneTime"` + // Profile whether or not to include a http server that profiles the code + Profile bool `yaml:"profile"` + // StubWg whether or not to stub the WireGuard types + StubWg bool `yaml:"stubWg"` } func ValidateConfiguration(c *WgMeshConfiguration) error { diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index 8115a3c..ab8fe8c 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -52,6 +52,11 @@ type QueryMesh struct { Query string } +type GetNodeArgs struct { + NodeId string + MeshId string +} + type MeshIpc interface { CreateMesh(args *NewMeshArgs, reply *string) error ListMeshes(name string, reply *ListMeshReply) error @@ -64,6 +69,7 @@ type MeshIpc interface { PutDescription(description string, reply *string) error PutAlias(alias string, reply *string) error PutService(args PutServiceArgs, reply *string) error + GetNode(args GetNodeArgs, reply *string) error DeleteService(service string, reply *string) error } diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 5dc86c3..ed78b9a 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -34,6 +34,7 @@ type MeshManager interface { Prune() error Close() error GetMonitor() MeshMonitor + GetNode(string, string) MeshNode } type MeshManagerImpl struct { @@ -79,6 +80,22 @@ func (m *MeshManagerImpl) SetService(service string, value string) error { return nil } +func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode { + mesh, ok := m.Meshes[meshid] + + if !ok { + return nil + } + + node, err := mesh.GetNode(nodeId) + + if err != nil { + return nil + } + + return node +} + // GetMonitor implements MeshManager. func (m *MeshManagerImpl) GetMonitor() MeshMonitor { return m.Monitor @@ -117,13 +134,15 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { return "", fmt.Errorf("error creating mesh: %w", err) } - err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ - IfName: devName, - Port: port, - }) + if !m.conf.StubWg { + err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ + IfName: devName, + Port: port, + }) - if err != nil { - return "", fmt.Errorf("error creating mesh: %w", err) + if err != nil { + return "", fmt.Errorf("error creating mesh: %w", err) + } } m.Meshes[meshId] = nodeManager @@ -159,10 +178,14 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { m.Meshes[params.MeshId] = meshProvider - return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ - IfName: params.DevName, - Port: params.WgPort, - }) + if !m.conf.StubWg { + return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{ + IfName: params.DevName, + Port: params.WgPort, + }) + } + + return nil } // HasChanges returns true if the mesh has changes @@ -195,6 +218,11 @@ func (s *MeshManagerImpl) EnableInterface(meshId string) error { // GetPublicKey: Gets the public key of the WireGuard mesh func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { + if s.conf.StubWg { + zeroedKey := make([]byte, wgtypes.KeyLen) + return (*wgtypes.Key)(zeroedKey), nil + } + mesh, ok := s.Meshes[meshId] if !ok { @@ -216,7 +244,6 @@ type AddSelfParams struct { // WgPort is the WireGuard port to advertise WgPort int // Endpoint is the alias of the machine to send routable packets - // to Endpoint string } @@ -247,16 +274,18 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { Endpoint: params.Endpoint, }) - device, err := mesh.GetDevice() + if !s.conf.StubWg { + device, err := mesh.GetDevice() - if err != nil { - return fmt.Errorf("failed to get device %w", err) - } + if err != nil { + return fmt.Errorf("failed to get device %w", err) + } - err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP)) + err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP)) - if err != nil { - return fmt.Errorf("addSelf: failed to add address to dev %w", err) + if err != nil { + return fmt.Errorf("addSelf: failed to add address to dev %w", err) + } } s.Meshes[params.MeshId].AddNode(node) @@ -277,13 +306,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { return err } - device, err := mesh.GetDevice() + if !s.conf.StubWg { + device, e := mesh.GetDevice() - if err != nil { - return err + if e != nil { + return err + } + + err = s.interfaceManipulator.RemoveInterface(device.Name) } - err = s.interfaceManipulator.RemoveInterface(device.Name) delete(s.Meshes, meshId) return err } @@ -295,15 +327,9 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { return nil, fmt.Errorf("mesh %s does not exist", meshId) } - snapshot, err := meshInstance.GetMesh() + node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint) if err != nil { - return nil, err - } - - node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint] - - if !ok { return nil, errors.New("the node doesn't exist in the mesh") } @@ -317,15 +343,17 @@ func (s *MeshManagerImpl) ApplyConfig() error { return err } - return s.RouteManager.InstallRoutes() + return nil } func (s *MeshManagerImpl) SetDescription(description string) error { for _, mesh := range s.Meshes { - err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) + if mesh.NodeExists(s.HostParameters.HostEndpoint) { + err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) - if err != nil { - return err + if err != nil { + return err + } } } @@ -335,10 +363,12 @@ func (s *MeshManagerImpl) SetDescription(description string) error { // SetAlias implements MeshManager. func (s *MeshManagerImpl) SetAlias(alias string) error { for _, mesh := range s.Meshes { - err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) + if mesh.NodeExists(s.HostParameters.HostEndpoint) { + err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) - if err != nil { - return err + if err != nil { + return err + } } } return nil @@ -347,16 +377,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error { // UpdateTimeStamp updates the timestamp of this node in all meshes func (s *MeshManagerImpl) UpdateTimeStamp() error { for _, mesh := range s.Meshes { - snapshot, err := mesh.GetMesh() - - if err != nil { - return err - } - - _, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint] - - if exists { - err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) + if mesh.NodeExists(s.HostParameters.HostEndpoint) { + err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) if err != nil { return err @@ -375,7 +397,12 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { return s.Meshes } +// Close the mesh manager func (s *MeshManagerImpl) Close() error { + if s.conf.StubWg { + return nil + } + for _, mesh := range s.Meshes { dev, err := mesh.GetDevice() diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 74fe859..b13e9ac 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -76,6 +76,21 @@ type MeshProviderStub struct { snapshot *MeshSnapshotStub } +// GetNodeIds implements MeshProvider. +func (*MeshProviderStub) GetNodeIds() []string { + panic("unimplemented") +} + +// GetNode implements MeshProvider. +func (*MeshProviderStub) GetNode(string) (MeshNode, error) { + panic("unimplemented") +} + +// NodeExists implements MeshProvider. +func (*MeshProviderStub) NodeExists(string) bool { + panic("unimplemented") +} + // AddService implements MeshProvider. func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { panic("unimplemented") @@ -196,6 +211,11 @@ type MeshManagerStub struct { meshes map[string]MeshProvider } +// GetNode implements MeshManager. +func (*MeshManagerStub) GetNode(string, string) MeshNode { + panic("unimplemented") +} + // RemoveService implements MeshManager. func (*MeshManagerStub) RemoveService(service string) error { panic("unimplemented") diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 8013c7a..7413508 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -92,7 +92,7 @@ type MeshSyncer interface { type MeshProvider interface { // AddNode() adds a node to the mesh AddNode(node MeshNode) - // GetMesh() returns a snapshot of the mesh provided by the mesh provider + // GetMesh() returns a snapshot of the mesh provided by the mesh provider. GetMesh() (MeshSnapshot, error) // GetMeshId() returns the ID of the mesh network GetMeshId() string @@ -114,6 +114,10 @@ type MeshProvider interface { RemoveRoutes(nodeId string, route ...string) error // GetSyncer: returns the automerge syncer for sync GetSyncer() MeshSyncer + // GetNode get a particular not within the mesh + GetNode(string) (MeshNode, error) + // NodeExists: returns true if a particular node exists false otherwise + NodeExists(string) bool // SetDescription: sets the description of this automerge data type SetDescription(nodeId string, description string) error // SetAlias: set the alias of the nodeId @@ -125,6 +129,7 @@ type MeshProvider interface { // Prune: prunes all nodes that have not updated their timestamp in // pruneAmount seconds Prune(pruneAmount int) error + GetNodeIds() []string } // HostParameters contains the IDs of a node diff --git a/pkg/query/query.go b/pkg/query/query.go index 5f06d96..52bb3a9 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -53,7 +53,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) { return nil, err } - nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode) + nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode) result, err := jmespath.Search(queryParams, nodes) @@ -65,7 +65,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) { return bytes, err } -func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { +func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode { queryNode := new(QueryNode) queryNode.HostEndpoint = node.GetHostEndpoint() pubKey, _ := node.GetPublicKey() diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index c201152..71c58ec 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -2,6 +2,7 @@ package robin import ( "context" + "encoding/json" "errors" "fmt" "strconv" @@ -10,6 +11,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/mesh" + "github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/wgmesh/pkg/rpc" ) @@ -242,6 +244,27 @@ func (n *IpcHandler) DeleteService(service string, reply *string) error { return nil } +func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error { + node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId) + + if node == nil { + *reply = "nil" + return nil + } + + queryNode := query.MeshNodeToQueryNode(node) + + bytes, err := json.Marshal(queryNode) + + if err != nil { + *reply = err.Error() + return nil + } + + *reply = string(bytes) + return nil +} + type RobinIpcParams struct { CtrlServer ctrlserver.CtrlServer } diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 155c1ed..ce8b23f 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -1,7 +1,6 @@ package sync import ( - "errors" "math/rand" "sync" "time" @@ -36,24 +35,16 @@ func (s *SyncerImpl) Sync(meshId string) error { } logging.Log.WriteInfof("UPDATING WG CONF") - err := s.manager.ApplyConfig() - if err != nil { - logging.Log.WriteInfof("Failed to update config %w", err) + if s.manager.HasChanges(meshId) { + err := s.manager.ApplyConfig() + + if err != nil { + logging.Log.WriteInfof("Failed to update config %w", err) + } } - theMesh := s.manager.GetMesh(meshId) - - if theMesh == nil { - return errors.New("the provided mesh does not exist") - } - - snapshot, _ := theMesh.GetMesh() - nodes := snapshot.GetNodes() - - if len(nodes) <= 1 { - return nil - } + nodeNames := s.manager.GetMesh(meshId).GetNodeIds() self, err := s.manager.GetSelf(meshId) @@ -61,17 +52,6 @@ func (s *SyncerImpl) Sync(meshId string) error { return err } - excludedNodes := map[string]struct{}{ - self.GetHostEndpoint(): {}, - } - meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes) - - getNames := func(node mesh.MeshNode) string { - return node.GetHostEndpoint() - } - - nodeNames := lib.Map(meshNodes, getNames) - neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint()) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) @@ -81,7 +61,7 @@ func (s *SyncerImpl) Sync(meshId string) error { before := time.Now() - if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { + if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { logging.Log.WriteInfof("Sending to random cluster") interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) randomSubset = append(randomSubset, interCluster) @@ -109,7 +89,8 @@ func (s *SyncerImpl) Sync(meshId string) error { // Check if any changes have occurred and trigger callbacks // if changes have occurred. - return s.manager.GetMonitor().Trigger() + // return s.manager.GetMonitor().Trigger() + return nil } // SyncMeshes: Sync all meshes