From 6e201ebaf51d414b8ba86cc34eea70eac3f888d0 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Tue, 21 Nov 2023 16:42:49 +0000 Subject: [PATCH 1/3] 24-keepalive-holepunch Nodes acting as peers and nodes acting as clients --- pkg/automerge/automerge.go | 30 ++++++++++++++- pkg/automerge/factory.go | 1 + pkg/automerge/types.go | 1 + pkg/conf/conf.go | 17 +++++++++ pkg/ctrlserver/ctrlserver.go | 2 +- pkg/ipc/ipc.go | 3 ++ pkg/mesh/config.go | 74 +++++++++++++++++++++++++++++++----- pkg/mesh/manager.go | 11 ++++++ pkg/mesh/stub_types.go | 19 +++++---- pkg/mesh/types.go | 11 +++++- pkg/query/query.go | 5 ++- pkg/sync/syncer.go | 2 +- 12 files changed, 155 insertions(+), 21 deletions(-) diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index f7278b2..e00fccc 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -42,8 +42,29 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) } -func (c *CrdtMeshManager) GetNodeIds() []string { +func (c *CrdtMeshManager) isPeer(nodeId string) bool { + node, err := c.doc.Path("nodes").Map().Get(nodeId) + + if err != nil || node.Kind() != automerge.KindMap { + return false + } + + nodeType, err := node.Map().Get("type") + + if err != nil || nodeType.Kind() != automerge.KindStr { + return false + } + + return nodeType.Str() == string(conf.PEER_ROLE) +} + +func (c *CrdtMeshManager) GetPeers() []string { keys, _ := c.doc.Path("nodes").Map().Keys() + + keys = lib.Filter(keys, func(s string) bool { + return c.isPeer(s) + }) + return keys } @@ -450,6 +471,12 @@ func (m *MeshNodeCrdt) GetServices() map[string]string { return services } +// GetType refers to the type of the node. Peer means that the node is globally accessible +// Client means the node is only accessible through another peer +func (n *MeshNodeCrdt) GetType() conf.NodeType { + return conf.NodeType(n.Type) +} + func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { nodes := make(map[string]mesh.MeshNode) @@ -464,6 +491,7 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { Description: node.Description, Alias: node.Alias, Services: node.GetServices(), + Type: node.Type, } } diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index 71148a6..b6861f8 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -38,6 +38,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod Routes: map[string]interface{}{}, Description: "", Alias: "", + Type: string(params.Role), } } diff --git a/pkg/automerge/types.go b/pkg/automerge/types.go index 2888b50..64b51c2 100644 --- a/pkg/automerge/types.go +++ b/pkg/automerge/types.go @@ -11,6 +11,7 @@ type MeshNodeCrdt struct { Alias string `automerge:"alias"` Description string `automerge:"description"` Services map[string]string `automerge:"services"` + Type string `automerge:"type"` } // MeshCrdt: Represents the mesh network as a whole diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index ad68721..1c3f4d6 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -16,6 +16,13 @@ func (m *WgMeshConfigurationError) Error() string { return m.msg } +type NodeType string + +const ( + PEER_ROLE NodeType = "peer" + CLIENT_ROLE NodeType = "client" +) + type WgMeshConfiguration struct { // CertificatePath is the path to the certificate to use in mTLS CertificatePath string `yaml:"certificatePath"` @@ -53,6 +60,12 @@ type WgMeshConfiguration struct { Profile bool `yaml:"profile"` // StubWg whether or not to stub the WireGuard types StubWg bool `yaml:"stubWg"` + // Role specifies whether or not the user is globally accessible. + // If the user is globaly accessible they specify themselves as a client. + Role NodeType `yaml:"role"` + // KeepAliveWg configures the implementation so that we send keep alive packets to peers. + // KeepAlive can only be set if role is type client + KeepAliveWg int `yaml:"keepAliveWg"` } func ValidateConfiguration(c *WgMeshConfiguration) error { @@ -134,6 +147,10 @@ func ValidateConfiguration(c *WgMeshConfiguration) error { } } + if c.Role == "" { + c.Role = PEER_ROLE + } + return nil } diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index e5acef3..fe779ee 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -35,7 +35,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ipAllocator := &ip.ULABuilder{} interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) - configApplyer := mesh.NewWgMeshConfigApplyer() + configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf) meshManagerParams := &mesh.NewMeshManagerParams{ Conf: *params.Conf, diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index 07487c6..a06515a 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -28,6 +28,9 @@ type JoinMeshArgs struct { // Endpoint is the routable address of this machine. If not provided // defaults to the default address Endpoint string + // Client specifies whether we should join as a client of the peer + // we are connecting to + Client bool } type PutServiceArgs struct { diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 35e9258..00f47a7 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -2,8 +2,13 @@ package mesh import ( "fmt" + "hash/fnv" "net" + "time" + "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" + "github.com/tim-beatham/wgmesh/pkg/route" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -16,10 +21,12 @@ type MeshConfigApplyer interface { // WgMeshConfigApplyer applies WireGuard configuration type WgMeshConfigApplyer struct { - meshManager MeshManager + meshManager MeshManager + config *conf.WgMeshConfiguration + routeInstaller route.RouteInstaller } -func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { +func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) if err != nil { @@ -40,10 +47,13 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { allowedips = append(allowedips, *ipnet) } + keepAlive := time.Duration(m.config.KeepAliveTime) * time.Second + peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - Endpoint: endpoint, - AllowedIPs: allowedips, + PublicKey: pubKey, + Endpoint: endpoint, + AllowedIPs: allowedips, + PersistentKeepaliveInterval: &keepAlive, } return &peerConfig, nil @@ -56,13 +66,56 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return err } - nodes := snap.GetNodes() + nodes := lib.MapValues(snap.GetNodes()) peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) + peers := lib.Filter(nodes, func(mn MeshNode) bool { + return mn.GetType() == conf.PEER_ROLE + }) + var count int = 0 + self, err := m.meshManager.GetSelf(mesh.GetMeshId()) + + if err != nil { + return err + } + + rtnl, err := lib.NewRtNetlinkConfig() + + if err != nil { + return err + } + for _, n := range nodes { - peer, err := convertMeshNode(n) + if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 { + a := fnv.New32a() + a.Write([]byte(n.GetHostEndpoint())) + sum := a.Sum32() + + responsiblePeer := peers[int(sum)%len(peers)] + + if responsiblePeer.GetHostEndpoint() != self.GetHostEndpoint() { + dev, err := mesh.GetDevice() + + if err != nil { + return err + } + + rtnl.AddRoute(dev.Name, lib.Route{ + Gateway: responsiblePeer.GetWgHost().IP, + Destination: *n.GetWgHost(), + }) + + if err != nil { + return err + } + + continue + } + } + + peer, err := m.convertMeshNode(n) if err != nil { return err @@ -122,6 +175,9 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) { m.meshManager = manager } -func NewWgMeshConfigApplyer() MeshConfigApplyer { - return &WgMeshConfigApplyer{} +func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer { + return &WgMeshConfigApplyer{ + config: config, + routeInstaller: route.NewRouteInstaller(), + } } diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 00fa08c..1fb9edf 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -256,6 +256,16 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) } + if params.WgPort == 0 && !s.conf.StubWg { + device, err := mesh.GetDevice() + + if err != nil { + return err + } + + params.WgPort = device.ListenPort + } + pubKey, err := s.GetPublicKey(params.MeshId) if err != nil { @@ -273,6 +283,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { NodeIP: nodeIP, WgPort: params.WgPort, Endpoint: params.Endpoint, + Role: s.conf.Role, }) if !s.conf.StubWg { diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 78515ba..7fe53e2 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -21,6 +21,11 @@ type MeshNodeStub struct { description string } +// GetType implements MeshNode. +func (*MeshNodeStub) GetType() conf.NodeType { + return PEER +} + // GetServices implements MeshNode. func (*MeshNodeStub) GetServices() map[string]string { return make(map[string]string) @@ -77,28 +82,28 @@ type MeshProviderStub struct { } // GetNodeIds implements MeshProvider. -func (*MeshProviderStub) GetNodeIds() []string { - panic("unimplemented") +func (*MeshProviderStub) GetPeers() []string { + return make([]string, 0) } // GetNode implements MeshProvider. func (*MeshProviderStub) GetNode(string) (MeshNode, error) { - panic("unimplemented") + return nil, nil } // NodeExists implements MeshProvider. func (*MeshProviderStub) NodeExists(string) bool { - panic("unimplemented") + return false } // AddService implements MeshProvider. func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { - panic("unimplemented") + return nil } // RemoveService implements MeshProvider. func (*MeshProviderStub) RemoveService(nodeId string, key string) error { - panic("unimplemented") + return nil } // SetAlias implements MeshProvider. @@ -108,7 +113,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error { // RemoveRoutes implements MeshProvider. func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { - panic("unimplemented") + return nil } // Prune implements MeshProvider. diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 7413508..4613bfc 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -11,6 +11,13 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const ( + // Data Exchanged Between Peers + PEER conf.NodeType = "peer" + // Data Exchanged Between Clients + CLIENT conf.NodeType = "client" +) + // MeshNode represents an implementation of a node in a mesh type MeshNode interface { // GetHostEndpoint: gets the gRPC endpoint of the node @@ -34,6 +41,7 @@ type MeshNode interface { GetAlias() string // GetServices: returns a list of services offered by the node GetServices() map[string]string + GetType() conf.NodeType } // NodeEquals: determines if two mesh nodes are equivalent to one another @@ -129,7 +137,7 @@ type MeshProvider interface { // Prune: prunes all nodes that have not updated their timestamp in // pruneAmount seconds Prune(pruneAmount int) error - GetNodeIds() []string + GetPeers() []string } // HostParameters contains the IDs of a node @@ -158,6 +166,7 @@ type MeshNodeFactoryParams struct { NodeIP net.IP WgPort int Endpoint string + Role conf.NodeType } // MeshBuilder build the hosts mesh node for it to be added to the mesh diff --git a/pkg/query/query.go b/pkg/query/query.go index 52bb3a9..878aa16 100644 --- a/pkg/query/query.go +++ b/pkg/query/query.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/jmespath/go-jmespath" + "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/mesh" ) @@ -28,11 +29,12 @@ type QueryNode struct { PublicKey string `json:"publicKey"` WgEndpoint string `json:"wgEndpoint"` WgHost string `json:"wgHost"` - Timestamp int64 `json:"timestmap"` + Timestamp int64 `json:"timestamp"` Description string `json:"description"` Routes []string `json:"routes"` Alias string `json:"alias"` Services map[string]string `json:"services"` + Type conf.NodeType `json:"type"` } func (m *QueryError) Error() string { @@ -80,6 +82,7 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode { queryNode.Description = node.GetDescription() queryNode.Alias = node.GetAlias() queryNode.Services = node.GetServices() + queryNode.Type = node.GetType() return queryNode } diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index ce8b23f..ac7c8cb 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -44,7 +44,7 @@ func (s *SyncerImpl) Sync(meshId string) error { } } - nodeNames := s.manager.GetMesh(meshId).GetNodeIds() + nodeNames := s.manager.GetMesh(meshId).GetPeers() self, err := s.manager.GetSelf(meshId) From 7b939e0468b07959f04c45436165f5a4d0465f9a Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Tue, 21 Nov 2023 20:42:43 +0000 Subject: [PATCH 2/3] 24-keepalive-holepunch Added the ability to hole punch NAT --- pkg/lib/hashing.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ pkg/mesh/config.go | 17 +++++++++-------- 2 files changed, 55 insertions(+), 8 deletions(-) create mode 100644 pkg/lib/hashing.go diff --git a/pkg/lib/hashing.go b/pkg/lib/hashing.go new file mode 100644 index 0000000..8bb40ab --- /dev/null +++ b/pkg/lib/hashing.go @@ -0,0 +1,46 @@ +package lib + +import ( + "hash/fnv" + "sort" +) + +type consistentHashRecord[V any] struct { + record V + value int +} + +func HashString(value string) int { + f := fnv.New32a() + f.Write([]byte(value)) + return int(f.Sum32()) +} + +// ConsistentHash implementation. Traverse the values until we find a key +// less than ours. +func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V { + if len(values) == 0 { + panic("values is empty") + } + + vs := Map(values, func(v V) consistentHashRecord[V] { + return consistentHashRecord[V]{ + v, + keyFunc(v), + } + }) + + sort.SliceStable(vs, func(i, j int) bool { + return vs[i].value < vs[j].value + }) + + ourKey := keyFunc(client) + + for _, record := range vs { + if ourKey < record.value { + return record.record + } + } + + return vs[0].record +} diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 00f47a7..55c0035 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -2,7 +2,6 @@ package mesh import ( "fmt" - "hash/fnv" "net" "time" @@ -88,14 +87,16 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } for _, n := range nodes { + if NodeEquals(n, self) { + continue + } + if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 { - a := fnv.New32a() - a.Write([]byte(n.GetHostEndpoint())) - sum := a.Sum32() + peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int { + return lib.HashString(mn.GetWgHost().String()) + }) - responsiblePeer := peers[int(sum)%len(peers)] - - if responsiblePeer.GetHostEndpoint() != self.GetHostEndpoint() { + if !NodeEquals(peer, self) { dev, err := mesh.GetDevice() if err != nil { @@ -103,7 +104,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { } rtnl.AddRoute(dev.Name, lib.Route{ - Gateway: responsiblePeer.GetWgHost().IP, + Gateway: peer.GetWgHost().IP, Destination: *n.GetWgHost(), }) From 624bd6e921080d141a841a1600aa2dcb989bdc04 Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Tue, 21 Nov 2023 21:26:31 +0000 Subject: [PATCH 3/3] 24-keepalive Persistent keep alive working --- pkg/mesh/config.go | 54 +++++++++++++++++++++++++++++----------------- pkg/mesh/types.go | 38 +------------------------------- 2 files changed, 35 insertions(+), 57 deletions(-) diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 55c0035..a784748 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -25,7 +25,7 @@ type WgMeshConfigApplyer struct { routeInstaller route.RouteInstaller } -func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { +func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) { endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) if err != nil { @@ -46,7 +46,13 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode) (*wgtypes.PeerConfi allowedips = append(allowedips, *ipnet) } - keepAlive := time.Duration(m.config.KeepAliveTime) * time.Second + clients, ok := peerToClients[node.GetWgHost().String()] + + if ok { + allowedips = append(allowedips, clients...) + } + + keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second peerConfig := wgtypes.PeerConfig{ PublicKey: pubKey, @@ -86,37 +92,45 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { return err } + peerToClients := make(map[string][]net.IPNet) + for _, n := range nodes { if NodeEquals(n, self) { continue } - if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 { + if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE { peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int { return lib.HashString(mn.GetWgHost().String()) }) - if !NodeEquals(peer, self) { - dev, err := mesh.GetDevice() + dev, err := mesh.GetDevice() - if err != nil { - return err - } - - rtnl.AddRoute(dev.Name, lib.Route{ - Gateway: peer.GetWgHost().IP, - Destination: *n.GetWgHost(), - }) - - if err != nil { - return err - } - - continue + if err != nil { + return err } + + rtnl.AddRoute(dev.Name, lib.Route{ + Gateway: peer.GetWgHost().IP, + Destination: *n.GetWgHost(), + }) + + if err != nil { + return err + } + + clients, ok := peerToClients[peer.GetWgHost().String()] + + if !ok { + clients = make([]net.IPNet, 0) + peerToClients[peer.GetWgHost().String()] = clients + } + + peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost()) + continue } - peer, err := m.convertMeshNode(n) + peer, err := m.convertMeshNode(n, peerToClients) if err != nil { return err diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 4613bfc..fe38b56 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -4,7 +4,6 @@ package mesh import ( "net" - "slices" "github.com/tim-beatham/wgmesh/pkg/conf" "golang.zx2c4.com/wireguard/wgctrl" @@ -46,42 +45,7 @@ type MeshNode interface { // NodeEquals: determines if two mesh nodes are equivalent to one another func NodeEquals(node1, node2 MeshNode) bool { - if node1.GetHostEndpoint() != node2.GetHostEndpoint() { - return false - } - - node1Pub, _ := node1.GetPublicKey() - node2Pub, _ := node2.GetPublicKey() - - if node1Pub != node2Pub { - return false - } - - if node1.GetWgEndpoint() != node2.GetWgEndpoint() { - return false - } - - if node1.GetWgHost() != node2.GetWgHost() { - return false - } - - if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) { - return false - } - - if node1.GetIdentifier() != node2.GetIdentifier() { - return false - } - - if node1.GetDescription() != node2.GetDescription() { - return false - } - - if node1.GetAlias() != node2.GetAlias() { - return false - } - - return true + return node1.GetHostEndpoint() == node2.GetHostEndpoint() } type MeshSnapshot interface {