1
0
forked from extern/smegmesh

Merge pull request #24 from tim-beatham/24-keepalive-holepunch

24 keepalive holepunch
This commit is contained in:
Tim Beatham 2023-11-21 21:28:16 +00:00 committed by GitHub
commit bf0724f6e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 217 additions and 58 deletions

View File

@ -42,8 +42,29 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
func (c *CrdtMeshManager) GetNodeIds() []string { func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys() keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(s string) bool {
return c.isPeer(s)
})
return keys return keys
} }
@ -450,6 +471,12 @@ func (m *MeshNodeCrdt) GetServices() map[string]string {
return services return services
} }
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -464,6 +491,7 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Description: node.Description, Description: node.Description,
Alias: node.Alias, Alias: node.Alias,
Services: node.GetServices(), Services: node.GetServices(),
Type: node.Type,
} }
} }

View File

@ -38,6 +38,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
Routes: map[string]interface{}{}, Routes: map[string]interface{}{},
Description: "", Description: "",
Alias: "", Alias: "",
Type: string(params.Role),
} }
} }

View File

@ -11,6 +11,7 @@ type MeshNodeCrdt struct {
Alias string `automerge:"alias"` Alias string `automerge:"alias"`
Description string `automerge:"description"` Description string `automerge:"description"`
Services map[string]string `automerge:"services"` Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

View File

@ -16,6 +16,13 @@ func (m *WgMeshConfigurationError) Error() string {
return m.msg return m.msg
} }
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type WgMeshConfiguration struct { type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath"`
@ -53,6 +60,12 @@ type WgMeshConfiguration struct {
Profile bool `yaml:"profile"` Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {
@ -134,6 +147,10 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.Role == "" {
c.Role = PEER_ROLE
}
return nil return nil
} }

View File

@ -35,7 +35,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer() configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,

View File

@ -28,6 +28,9 @@ type JoinMeshArgs struct {
// Endpoint is the routable address of this machine. If not provided // Endpoint is the routable address of this machine. If not provided
// defaults to the default address // defaults to the default address
Endpoint string Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
} }
type PutServiceArgs struct { type PutServiceArgs struct {

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any](values []V, client V, keyFunc func(V) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
keyFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -3,7 +3,11 @@ package mesh
import ( import (
"fmt" "fmt"
"net" "net"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -17,9 +21,11 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
} }
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, peerToClients map[string][]net.IPNet) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil { if err != nil {
@ -40,10 +46,19 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips = append(allowedips, *ipnet) allowedips = append(allowedips, *ipnet)
} }
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey, PublicKey: pubKey,
Endpoint: endpoint, Endpoint: endpoint,
AllowedIPs: allowedips, AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
} }
return &peerConfig, nil return &peerConfig, nil
@ -56,13 +71,66 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err return err
} }
nodes := snap.GetNodes() nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0 var count int = 0
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
for _, n := range nodes { for _, n := range nodes {
peer, err := convertMeshNode(n) if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
peer := lib.ConsistentHash(peers, n, func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
})
dev, err := mesh.GetDevice()
if err != nil {
return err
}
rtnl.AddRoute(dev.Name, lib.Route{
Gateway: peer.GetWgHost().IP,
Destination: *n.GetWgHost(),
})
if err != nil {
return err
}
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
peer, err := m.convertMeshNode(n, peerToClients)
if err != nil { if err != nil {
return err return err
@ -122,6 +190,9 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }
func NewWgMeshConfigApplyer() MeshConfigApplyer { func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{} return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
}
} }

View File

@ -256,6 +256,16 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
} }
if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
params.WgPort = device.ListenPort
}
pubKey, err := s.GetPublicKey(params.MeshId) pubKey, err := s.GetPublicKey(params.MeshId)
if err != nil { if err != nil {
@ -273,6 +283,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
NodeIP: nodeIP, NodeIP: nodeIP,
WgPort: params.WgPort, WgPort: params.WgPort,
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
Role: s.conf.Role,
}) })
if !s.conf.StubWg { if !s.conf.StubWg {

View File

@ -21,6 +21,11 @@ type MeshNodeStub struct {
description string description string
} }
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return PEER
}
// GetServices implements MeshNode. // GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string { func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string) return make(map[string]string)
@ -77,28 +82,28 @@ type MeshProviderStub struct {
} }
// GetNodeIds implements MeshProvider. // GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string { func (*MeshProviderStub) GetPeers() []string {
panic("unimplemented") return make([]string, 0)
} }
// GetNode implements MeshProvider. // GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) { func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented") return nil, nil
} }
// NodeExists implements MeshProvider. // NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool { func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented") return false
} }
// AddService implements MeshProvider. // AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented") return nil
} }
// RemoveService implements MeshProvider. // RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error { func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented") return nil
} }
// SetAlias implements MeshProvider. // SetAlias implements MeshProvider.
@ -108,7 +113,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented") return nil
} }
// Prune implements MeshProvider. // Prune implements MeshProvider.

View File

@ -4,13 +4,19 @@ package mesh
import ( import (
"net" "net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
const (
// Data Exchanged Between Peers
PEER conf.NodeType = "peer"
// Data Exchanged Between Clients
CLIENT conf.NodeType = "client"
)
// MeshNode represents an implementation of a node in a mesh // MeshNode represents an implementation of a node in a mesh
type MeshNode interface { type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node // GetHostEndpoint: gets the gRPC endpoint of the node
@ -34,46 +40,12 @@ type MeshNode interface {
GetAlias() string GetAlias() string
// GetServices: returns a list of services offered by the node // GetServices: returns a list of services offered by the node
GetServices() map[string]string GetServices() map[string]string
GetType() conf.NodeType
} }
// NodeEquals: determines if two mesh nodes are equivalent to one another // NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool { func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() { return node1.GetHostEndpoint() == node2.GetHostEndpoint()
return false
}
node1Pub, _ := node1.GetPublicKey()
node2Pub, _ := node2.GetPublicKey()
if node1Pub != node2Pub {
return false
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -129,7 +101,7 @@ type MeshProvider interface {
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error
GetNodeIds() []string GetPeers() []string
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node
@ -158,6 +130,7 @@ type MeshNodeFactoryParams struct {
NodeIP net.IP NodeIP net.IP
WgPort int WgPort int
Endpoint string Endpoint string
Role conf.NodeType
} }
// MeshBuilder build the hosts mesh node for it to be added to the mesh // MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
@ -28,11 +29,12 @@ type QueryNode struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"` WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []string `json:"routes"`
Alias string `json:"alias"` Alias string `json:"alias"`
Services map[string]string `json:"services"` Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
@ -80,6 +82,7 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias() queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices() queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode return queryNode
} }

View File

@ -44,7 +44,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
} }
} }
nodeNames := s.manager.GetMesh(meshId).GetNodeIds() nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId) self, err := s.manager.GetSelf(meshId)