Merge pull request #17 from tim-beatham/25-ability-to-aliases

25 ability to aliases
This commit is contained in:
Tim Beatham 2023-11-17 22:20:57 +00:00 committed by GitHub
commit 023565d985
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 657 additions and 43 deletions

View File

@ -13,5 +13,5 @@ func main() {
log.Fatal(err.Error()) log.Fatal(err.Error())
} }
apiServer.Run(":8080") apiServer.Run(":40000")
} }

View File

@ -171,6 +171,51 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply) fmt.Println(reply)
} }
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes") "wg-mesh Manipulate WireGuard meshes")
@ -184,6 +229,9 @@ func main() {
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath") queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node") putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
@ -195,8 +243,6 @@ func main() {
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +254,13 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true}) var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -245,10 +298,6 @@ func main() {
})) }))
} }
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() { if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId) getGraph(client, *getGraphMeshId)
} }
@ -268,4 +317,16 @@ func main() {
if putDescriptionCmd.Happened() { if putDescriptionCmd.Happened() {
putDescription(client, *description) putDescription(client, *description)
} }
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
} }

View File

@ -45,8 +45,8 @@ func main() {
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)

View File

@ -37,6 +37,8 @@ func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
Description: meshNode.Description, Description: meshNode.Description,
Routes: meshNode.Routes, Routes: meshNode.Routes,
PublicKey: meshNode.PublicKey, PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias,
Services: meshNode.Services,
} }
} }

View File

@ -1,6 +1,7 @@
package api package api
type SmegNode struct { type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"` WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
@ -8,6 +9,7 @@ type SmegNode struct {
Description string `json:"description"` Description string `json:"description"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Routes []string `json:"routes"` Routes []string `json:"routes"`
Services map[string]string `json:"services"`
} }
type SmegMesh struct { type SmegMesh struct {

View File

@ -20,7 +20,6 @@ import (
type CrdtMeshManager struct { type CrdtMeshManager struct {
MeshId string MeshId string
IfName string IfName string
NodeId string
Client *wgctrl.Client Client *wgctrl.Client
doc *automerge.Doc doc *automerge.Doc
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
@ -35,8 +34,9 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
} }
crdt.Routes = make(map[string]interface{}) crdt.Routes = make(map[string]interface{})
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
@ -178,6 +178,72 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err return err
} }
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -336,6 +402,20 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -348,6 +428,8 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp, Timestamp: node.Timestamp,
Routes: node.Routes, Routes: node.Routes,
Description: node.Description, Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
} }
} }

View File

@ -36,6 +36,8 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
// Always set the routes as empty. // Always set the routes as empty.
// Routes handled by external component // Routes handled by external component
Routes: map[string]interface{}{}, Routes: map[string]interface{}{},
Description: "",
Alias: "",
} }
} }

View File

@ -8,7 +8,9 @@ type MeshNodeCrdt struct {
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"` Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"` Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"` Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

View File

@ -18,6 +18,8 @@ type MeshNode struct {
Timestamp int64 Timestamp int64
Routes []string Routes []string
Description string Description string
Alias string
Services map[string]string
} }
// Represents a WireGuard Mesh // Represents a WireGuard Mesh

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -34,6 +34,11 @@ type JoinMeshArgs struct {
Endpoint string Endpoint string
} }
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct { type GetMeshReply struct {
Nodes []ctrlserver.MeshNode Nodes []ctrlserver.MeshNode
} }
@ -57,6 +62,9 @@ type MeshIpc interface {
GetDOT(meshId string, reply *string) error GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(service string, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -25,11 +25,15 @@ type MeshManager interface {
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
Prune() error Prune() error
Close() error Close() error
GetMonitor() MeshMonitor
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
@ -46,6 +50,38 @@ type MeshManagerImpl struct {
idGenerator lib.IdGenerator idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
if err != nil {
return err
}
}
return nil
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -296,6 +332,18 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
return nil return nil
} }
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if err != nil {
return err
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
@ -359,7 +407,7 @@ type NewMeshManagerParams struct {
} }
// Creates a new instance of a mesh manager with the given parameters // Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{} hostParams := HostParameters{}
switch params.Conf.Endpoint { switch params.Conf.Endpoint {
@ -390,5 +438,11 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
m.idGenerator = params.IdGenerator m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m return m
} }

View File

@ -22,7 +22,7 @@ func getMeshConfiguration() *conf.WgMeshConfiguration {
} }
} }
func getMeshManager() *MeshManagerImpl { func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{ manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(), Conf: *getMeshConfiguration(),
Client: nil, Client: nil,
@ -51,7 +51,7 @@ func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
t.Fatal(`meshId should not be empty`) t.Fatal(`meshId should not be empty`)
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if !exists { if !exists {
t.Fatal(`mesh was not created when it should be`) t.Fatal(`mesh was not created when it should be`)
@ -190,7 +190,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
t.Fatalf("%s", err.Error()) t.Fatalf("%s", err.Error())
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if exists { if exists {
t.Fatalf(`expected mesh to have been deleted`) t.Fatalf(`expected mesh to have been deleted`)

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -21,6 +21,16 @@ type MeshNodeStub struct {
description string description string
} }
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string { func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint return m.hostEndpoint
} }
@ -66,6 +76,21 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented")
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented")
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented") panic("unimplemented")
@ -171,6 +196,26 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") panic("unimplemented")

View File

@ -4,6 +4,7 @@ package mesh
import ( import (
"net" "net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
@ -28,6 +29,51 @@ type MeshNode interface {
GetIdentifier() string GetIdentifier() string
// GetDescription: returns the description for this node // GetDescription: returns the description for this node
GetDescription() string GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() {
return false
}
node1Pub, _ := node1.GetPublicKey()
node2Pub, _ := node2.GetPublicKey()
if node1Pub != node2Pub {
return false
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -70,6 +116,12 @@ type MeshProvider interface {
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// SetDescription: sets the description of this automerge data type // SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error

View File

@ -31,6 +31,8 @@ type QueryNode struct {
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestmap"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []string `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
@ -76,6 +78,9 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.Timestamp = node.GetTimeStamp() queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes() queryNode.Routes = node.GetRoutes()
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
return queryNode return queryNode
} }

View File

@ -117,6 +117,11 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId) mesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh() meshSnapshot, err := mesh.GetMesh()
if err != nil { if err != nil {
@ -145,6 +150,8 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
Timestamp: node.GetTimeStamp(), Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(), Routes: node.GetRoutes(),
Description: node.GetDescription(), Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
} }
nodes[i] = node nodes[i] = node
@ -202,6 +209,39 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -30,6 +30,11 @@ type SyncerImpl struct {
// Sync: Sync random nodes // Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(meshId string) error {
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
logging.Log.WriteInfof("UPDATING WG CONF") logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig() err := s.manager.ApplyConfig()
@ -37,23 +42,13 @@ func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("Failed to update config %w", err) logging.Log.WriteInfof("Failed to update config %w", err)
} }
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
theMesh := s.manager.GetMesh(meshId) theMesh := s.manager.GetMesh(meshId)
if theMesh == nil { if theMesh == nil {
return errors.New("the provided mesh does not exist") return errors.New("the provided mesh does not exist")
} }
snapshot, err := theMesh.GetMesh() snapshot, _ := theMesh.GetMesh()
if err != nil {
return err
}
nodes := snapshot.GetNodes() nodes := snapshot.GetNodes()
if len(nodes) <= 1 { if len(nodes) <= 1 {
@ -107,16 +102,19 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait() waitGroup.Wait()
s.syncCount++ s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before)) logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
return nil
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
return s.manager.GetMonitor().Trigger()
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error { func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.GetMeshes() { for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId) err := s.Sync(meshId)
if err != nil { if err != nil {