diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 8925b7d..2eab14b 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -171,6 +171,20 @@ func putDescription(client *ipcRpc.Client, description string) { fmt.Println(reply) } +// putAlias: puts an alias for the node +func putAlias(client *ipcRpc.Client, alias string) { + var reply string + + err := client.Call("IpcHandler.PutAlias", &alias, &reply) + + if err != nil { + fmt.Println(err.Error()) + return + } + + fmt.Println(reply) +} + func main() { parser := argparse.NewParser("wg-mesh", "wg-mesh Manipulate WireGuard meshes") @@ -184,6 +198,7 @@ func main() { leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath") putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node") + putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node") var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) @@ -208,6 +223,8 @@ func main() { var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true}) + var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true}) + err := parser.Parse(os.Args) if err != nil { @@ -245,10 +262,6 @@ func main() { })) } - // if getMeshCmd.Happened() { - // getMesh(client, *getMeshId) - // } - if getGraphCmd.Happened() { getGraph(client, *getGraphMeshId) } @@ -268,4 +281,8 @@ func main() { if putDescriptionCmd.Happened() { putDescription(client, *description) } + + if putAliasCmd.Happened() { + putAlias(client, *alias) + } } diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index facc6bf..6261545 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -45,8 +45,8 @@ func main() { SyncProvider: &syncProvider, Client: client, } - ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) + ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) syncProvider.Server = ctrlServer syncRequester := sync.NewSyncRequester(ctrlServer) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 761a374..f6ee57d 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -178,6 +178,26 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro return err } +func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { + node, err := m.doc.Path("nodes").Map().Get(nodeId) + + if err != nil { + return err + } + + if node.Kind() != automerge.KindMap { + return fmt.Errorf("%s does not exist", nodeId) + } + + err = node.Map().Set("alias", alias) + + if err == nil { + logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias) + } + + return err +} + // AddRoutes: adds routes to the specific nodeId func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) @@ -336,6 +356,10 @@ func (m *MeshNodeCrdt) GetIdentifier() string { return strings.Join(constituents, ":") } +func (m *MeshNodeCrdt) GetAlias() string { + return m.Alias +} + func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { nodes := make(map[string]mesh.MeshNode) @@ -348,6 +372,7 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { Timestamp: node.Timestamp, Routes: node.Routes, Description: node.Description, + Alias: node.Alias, } } diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index 188d5d0..71148a6 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -35,7 +35,9 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), // Always set the routes as empty. // Routes handled by external component - Routes: map[string]interface{}{}, + Routes: map[string]interface{}{}, + Description: "", + Alias: "", } } diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index 60315fe..01b6b69 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -8,6 +8,7 @@ type MeshNodeCrdt struct { WgHost string `automerge:"wgHost"` Timestamp int64 `automerge:"timestamp"` Routes map[string]interface{} `automerge:"routes"` + Alias string `automerge:"alias"` Description string `automerge:"description"` } diff --git a/pkg/hosts/hosts.go b/pkg/hosts/hosts.go new file mode 100644 index 0000000..735b91c --- /dev/null +++ b/pkg/hosts/hosts.go @@ -0,0 +1,131 @@ +// hosts: utility for modifying the /etc/hosts file +package hosts + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "os" + "strings" +) + +// HOSTS_FILE is the hosts file location +const HOSTS_FILE = "/etc/hosts" + +const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS" +const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END" + +// Generic interface to manipulate /etc/hosts file +type HostsManipulator interface { + // AddrAddr associates an aliasd with a given IP address + AddAddr(ipAddr net.IP, alias string) + // Remove deletes the entry from /etc/hosts + Remove(alias string) + // Writes the changes to /etc/hosts file + Write() error +} + +type HostsManipulatorImpl struct { + hosts map[string]net.IP + meshid string +} + +// AddAddr implements HostsManipulator. +func (m *HostsManipulatorImpl) AddAddr(ipAddr net.IP, alias string) { + m.hosts[alias] = ipAddr +} + +// Remove implements HostsManipulator. +func (m *HostsManipulatorImpl) Remove(alias string) { + delete(m.hosts, alias) +} + +type HostsEntry struct { + Alias string + Ip net.IP +} + +func (m *HostsManipulatorImpl) removeHosts() string { + hostsFile, err := os.ReadFile(HOSTS_FILE) + + if err != nil { + return "" + } + + var contents strings.Builder + + scanner := bufio.NewScanner(bytes.NewReader(hostsFile)) + + hostsSection := false + + for scanner.Scan() { + line := scanner.Text() + + if err == io.EOF { + break + } else if err != nil { + return "" + } + + if !hostsSection && strings.Contains(line, DOMAIN_HEADER+m.meshid) { + hostsSection = true + } + + if !hostsSection { + contents.WriteString(line + "\n") + } + + if hostsSection && strings.Contains(line, DOMAIN_TRAILER+m.meshid) { + hostsSection = false + } + } + + if scanner.Err() != nil && scanner.Err() != io.EOF { + return "" + } + + return contents.String() +} + +// Write implements HostsManipulator +func (m *HostsManipulatorImpl) Write() error { + contents := m.removeHosts() + + var nextHosts strings.Builder + nextHosts.WriteString(contents) + + nextHosts.WriteString(DOMAIN_HEADER + m.meshid + "\n") + + for alias, ip := range m.hosts { + nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", ip.String(), alias)) + } + + nextHosts.WriteString(DOMAIN_TRAILER + m.meshid + "\n") + return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644) +} + +// parseLine parses a line in the /etc/hosts file +func parseLine(line string) (*HostsEntry, error) { + fields := strings.Fields(line) + + if len(fields) != 2 { + return nil, fmt.Errorf("expected entry length of 2 was %d", len(fields)) + } + + ipAddr := fields[0] + alias := fields[1] + + ip := net.ParseIP(ipAddr) + + if ip == nil { + return nil, fmt.Errorf("failed to parse ip for %s", alias) + } + + return &HostsEntry{Ip: ip, Alias: alias}, nil +} + +func NewHostsManipulator(meshId string) HostsManipulator { + return &HostsManipulatorImpl{hosts: make(map[string]net.IP), meshid: meshId} +} diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index e816b5b..54d6b0f 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -57,6 +57,7 @@ type MeshIpc interface { GetDOT(meshId string, reply *string) error Query(query QueryMesh, reply *string) error PutDescription(description string, reply *string) error + PutAlias(alias string, reply *string) error } const SockAddr = "/tmp/wgmesh_ipc.sock" diff --git a/pkg/mesh/alias.go b/pkg/mesh/alias.go new file mode 100644 index 0000000..f54324d --- /dev/null +++ b/pkg/mesh/alias.go @@ -0,0 +1,15 @@ +package mesh + +import "github.com/tim-beatham/wgmesh/pkg/hosts" + +func AddAliases(meshid string, snapshot MeshSnapshot) { + hosts := hosts.NewHostsManipulator(meshid) + + for _, node := range snapshot.GetNodes() { + if node.GetAlias() != "" { + hosts.AddAddr(node.GetWgHost().IP, node.GetAlias()) + } + } + + hosts.Write() +} diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index bb412b4..d0ae9a7 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -25,11 +25,13 @@ type MeshManager interface { GetSelf(meshId string) (MeshNode, error) ApplyConfig() error SetDescription(description string) error + SetAlias(alias string) error UpdateTimeStamp() error GetClient() *wgctrl.Client GetMeshes() map[string]MeshProvider Prune() error Close() error + GetMonitor() MeshMonitor } type MeshManagerImpl struct { @@ -46,6 +48,12 @@ type MeshManagerImpl struct { idGenerator lib.IdGenerator ipAllocator ip.IPAllocator interfaceManipulator wg.WgInterfaceManipulator + Monitor MeshMonitor +} + +// GetMonitor implements MeshManager. +func (m *MeshManagerImpl) GetMonitor() MeshMonitor { + return m.Monitor } // Prune implements MeshManager. @@ -296,6 +304,18 @@ func (s *MeshManagerImpl) SetDescription(description string) error { return nil } +// SetAlias implements MeshManager. +func (s *MeshManagerImpl) SetAlias(alias string) error { + for _, mesh := range s.Meshes { + err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) + + if err != nil { + return err + } + } + return nil +} + // UpdateTimeStamp updates the timestamp of this node in all meshes func (s *MeshManagerImpl) UpdateTimeStamp() error { for _, mesh := range s.Meshes { @@ -359,7 +379,7 @@ type NewMeshManagerParams struct { } // Creates a new instance of a mesh manager with the given parameters -func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { +func NewMeshManager(params *NewMeshManagerParams) MeshManager { hostParams := HostParameters{} switch params.Conf.Endpoint { @@ -390,5 +410,8 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { m.idGenerator = params.IdGenerator m.ipAllocator = params.IPAllocator m.interfaceManipulator = params.InterfaceManipulator + + m.Monitor = NewMeshMonitor() + m.Monitor.AddCallback(AddAliases) return m } diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index 65105b1..ad2a14a 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -22,7 +22,7 @@ func getMeshConfiguration() *conf.WgMeshConfiguration { } } -func getMeshManager() *MeshManagerImpl { +func getMeshManager() MeshManager { manager := NewMeshManager(&NewMeshManagerParams{ Conf: *getMeshConfiguration(), Client: nil, diff --git a/pkg/mesh/monitor.go b/pkg/mesh/monitor.go new file mode 100644 index 0000000..a0a098c --- /dev/null +++ b/pkg/mesh/monitor.go @@ -0,0 +1,28 @@ +package mesh + +type OnChange = func(string, MeshSnapshot) + +type MeshMonitor interface { + AddCallback(cb OnChange) + Trigger(meshid string, m MeshSnapshot) +} + +type MeshMonitorImpl struct { + callbacks []OnChange +} + +func (m *MeshMonitorImpl) Trigger(meshid string, snapshot MeshSnapshot) { + for _, cb := range m.callbacks { + cb(meshid, snapshot) + } +} + +func (m *MeshMonitorImpl) AddCallback(cb OnChange) { + m.callbacks = append(m.callbacks, cb) +} + +func NewMeshMonitor() MeshMonitor { + return &MeshMonitorImpl{ + callbacks: make([]OnChange, 0), + } +} diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index c328c76..49b3900 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -21,6 +21,11 @@ type MeshNodeStub struct { description string } +// GetAlias implements MeshNode. +func (*MeshNodeStub) GetAlias() string { + panic("unimplemented") +} + func (m *MeshNodeStub) GetHostEndpoint() string { return m.hostEndpoint } @@ -66,6 +71,11 @@ type MeshProviderStub struct { snapshot *MeshSnapshotStub } +// SetAlias implements MeshProvider. +func (*MeshProviderStub) SetAlias(nodeId string, alias string) error { + panic("unimplemented") +} + // RemoveRoutes implements MeshProvider. func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { panic("unimplemented") @@ -171,6 +181,16 @@ type MeshManagerStub struct { meshes map[string]MeshProvider } +// GetMonitor implements MeshManager. +func (*MeshManagerStub) GetMonitor() MeshMonitor { + panic("unimplemented") +} + +// SetAlias implements MeshManager. +func (*MeshManagerStub) SetAlias(alias string) error { + panic("unimplemented") +} + // Close implements MeshManager. func (*MeshManagerStub) Close() error { panic("unimplemented") diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 5543820..6ab174d 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -28,6 +28,9 @@ type MeshNode interface { GetIdentifier() string // GetDescription: returns the description for this node GetDescription() string + // GetAlias: associates the node with an alias. Potentially used + // for DNS and so forth. + GetAlias() string } type MeshSnapshot interface { @@ -70,6 +73,8 @@ type MeshProvider interface { GetSyncer() MeshSyncer // SetDescription: sets the description of this automerge data type SetDescription(nodeId string, description string) error + // SetAlias: set the alias of the nodeId + SetAlias(nodeId string, alias string) error // Prune: prunes all nodes that have not updated their timestamp in // pruneAmount seconds Prune(pruneAmount int) error diff --git a/pkg/query/query.go b/pkg/query/query.go index 0978f08..16bdc00 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -31,6 +31,7 @@ type QueryNode struct { Timestamp int64 `json:"timestmap"` Description string `json:"description"` Routes []string `json:"routes"` + Alias string `json:"alias"` } func (m *QueryError) Error() string { @@ -76,6 +77,8 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { queryNode.Timestamp = node.GetTimeStamp() queryNode.Routes = node.GetRoutes() queryNode.Description = node.GetDescription() + queryNode.Alias = node.GetAlias() + return queryNode } diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 08a38b4..337cac3 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -202,6 +202,17 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error { return nil } +func (n *IpcHandler) PutAlias(alias string, reply *string) error { + err := n.Server.GetMeshManager().SetAlias(alias) + + if err != nil { + return err + } + + *reply = fmt.Sprintf("Set alias to %s", alias) + return nil +} + type RobinIpcParams struct { CtrlServer ctrlserver.CtrlServer } diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 834317b..040d13a 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -30,6 +30,11 @@ type SyncerImpl struct { // Sync: Sync random nodes func (s *SyncerImpl) Sync(meshId string) error { + if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { + logging.Log.WriteInfof("No changes for %s", meshId) + return nil + } + logging.Log.WriteInfof("UPDATING WG CONF") err := s.manager.ApplyConfig() @@ -37,11 +42,6 @@ func (s *SyncerImpl) Sync(meshId string) error { logging.Log.WriteInfof("Failed to update config %w", err) } - if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { - logging.Log.WriteInfof("No changes for %s", meshId) - return nil - } - theMesh := s.manager.GetMesh(meshId) if theMesh == nil { @@ -50,6 +50,8 @@ func (s *SyncerImpl) Sync(meshId string) error { snapshot, err := theMesh.GetMesh() + s.manager.GetMonitor().Trigger(meshId, snapshot) + if err != nil { return err } @@ -111,6 +113,14 @@ func (s *SyncerImpl) Sync(meshId string) error { logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) + + newMesh := s.manager.GetMesh(meshId) + snapshot, err = newMesh.GetMesh() + + if err != nil { + return err + } + return nil }