mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-08-12 14:37:10 +02:00
Compare commits
14 Commits
41-bugfix-
...
53-run-com
Author | SHA1 | Date | |
---|---|---|---|
fe14f63217 | |||
4a8a39601f | |||
1e263cc6a8 | |||
dae9cd31a1 | |||
f855f53fbf | |||
52feb5767b | |||
815c4484ee | |||
0058c9f4c9 | |||
92c0805275 | |||
661fb0d54c | |||
64885f1055 | |||
2169f7796f | |||
a3ceff019d | |||
b78d96986c |
@ -4,13 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
ipcRpc "net/rpc"
|
ipcRpc "net/rpc"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/akamensky/argparse"
|
"github.com/akamensky/argparse"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,6 +16,7 @@ type CreateMeshParams struct {
|
|||||||
Client *ipcRpc.Client
|
Client *ipcRpc.Client
|
||||||
WgPort int
|
WgPort int
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
Role string
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMesh(args *CreateMeshParams) string {
|
func createMesh(args *CreateMeshParams) string {
|
||||||
@ -27,6 +24,7 @@ func createMesh(args *CreateMeshParams) string {
|
|||||||
newMeshParams := ipc.NewMeshArgs{
|
newMeshParams := ipc.NewMeshArgs{
|
||||||
WgPort: args.WgPort,
|
WgPort: args.WgPort,
|
||||||
Endpoint: args.Endpoint,
|
Endpoint: args.Endpoint,
|
||||||
|
Role: args.Role,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
|
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
|
||||||
@ -60,6 +58,7 @@ type JoinMeshParams struct {
|
|||||||
IfName string
|
IfName string
|
||||||
WgPort int
|
WgPort int
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
Role string
|
||||||
}
|
}
|
||||||
|
|
||||||
func joinMesh(params *JoinMeshParams) string {
|
func joinMesh(params *JoinMeshParams) string {
|
||||||
@ -69,6 +68,7 @@ func joinMesh(params *JoinMeshParams) string {
|
|||||||
MeshId: params.MeshId,
|
MeshId: params.MeshId,
|
||||||
IpAdress: params.IpAddress,
|
IpAdress: params.IpAddress,
|
||||||
Port: params.WgPort,
|
Port: params.WgPort,
|
||||||
|
Role: params.Role,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
|
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
|
||||||
@ -80,34 +80,6 @@ func joinMesh(params *JoinMeshParams) string {
|
|||||||
return reply
|
return reply
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMesh(client *ipcRpc.Client, meshId string) {
|
|
||||||
reply := new(ipc.GetMeshReply)
|
|
||||||
|
|
||||||
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, node := range reply.Nodes {
|
|
||||||
fmt.Println("Public Key: " + node.PublicKey)
|
|
||||||
fmt.Println("Control Endpoint: " + node.HostEndpoint)
|
|
||||||
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
|
|
||||||
fmt.Println("Wg IP: " + node.WgHost)
|
|
||||||
fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
|
|
||||||
|
|
||||||
mapFunc := func(r ctrlserver.MeshRoute) string {
|
|
||||||
return r.Destination
|
|
||||||
}
|
|
||||||
|
|
||||||
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
|
|
||||||
fmt.Printf("Routes: %s\n", advertiseRoutes)
|
|
||||||
|
|
||||||
fmt.Println("---")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func leaveMesh(client *ipcRpc.Client, meshId string) {
|
func leaveMesh(client *ipcRpc.Client, meshId string) {
|
||||||
var reply string
|
var reply string
|
||||||
|
|
||||||
@ -255,11 +227,13 @@ func main() {
|
|||||||
|
|
||||||
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
|
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
|
||||||
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
|
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
|
||||||
|
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{})
|
||||||
|
|
||||||
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
|
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
|
||||||
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
|
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
|
||||||
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
|
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
|
||||||
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
|
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
|
||||||
|
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{})
|
||||||
|
|
||||||
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
|
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
|
||||||
|
|
||||||
@ -300,6 +274,7 @@ func main() {
|
|||||||
Client: client,
|
Client: client,
|
||||||
WgPort: *newMeshPort,
|
WgPort: *newMeshPort,
|
||||||
Endpoint: *newMeshEndpoint,
|
Endpoint: *newMeshEndpoint,
|
||||||
|
Role: *newMeshRole,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,6 +289,7 @@ func main() {
|
|||||||
IpAddress: *joinMeshIpAddress,
|
IpAddress: *joinMeshIpAddress,
|
||||||
MeshId: *joinMeshId,
|
MeshId: *joinMeshId,
|
||||||
Endpoint: *joinMeshEndpoint,
|
Endpoint: *joinMeshEndpoint,
|
||||||
|
Role: *joinMeshRole,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,13 +19,13 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) != 2 {
|
if len(os.Args) != 2 {
|
||||||
logging.Log.WriteErrorf("Need to provide configuration.yaml")
|
logging.Log.WriteErrorf("Did not provide configuration")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conf, err := conf.ParseConfiguration(os.Args[1])
|
conf, err := conf.ParseDaemonConfiguration(os.Args[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof("Could not parse configuration")
|
logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,8 +63,7 @@ func main() {
|
|||||||
syncRequester = sync.NewSyncRequester(ctrlServer)
|
syncRequester = sync.NewSyncRequester(ctrlServer)
|
||||||
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
|
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)
|
||||||
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
|
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer)
|
||||||
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
|
keepAlive := timer.NewTimestampScheduler(ctrlServer)
|
||||||
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
|
|
||||||
|
|
||||||
robinIpcParams := robin.RobinIpcParams{
|
robinIpcParams := robin.RobinIpcParams{
|
||||||
CtrlServer: ctrlServer,
|
CtrlServer: ctrlServer,
|
||||||
@ -82,13 +81,12 @@ func main() {
|
|||||||
|
|
||||||
go ipc.RunIpcHandler(&robinIpc)
|
go ipc.RunIpcHandler(&robinIpc)
|
||||||
go syncScheduler.Run()
|
go syncScheduler.Run()
|
||||||
go timestampScheduler.Run()
|
go keepAlive.Run()
|
||||||
go pruneScheduler.Run()
|
|
||||||
|
|
||||||
closeResources := func() {
|
closeResources := func() {
|
||||||
logging.Log.WriteInfof("Closing resources")
|
logging.Log.WriteInfof("Closing resources")
|
||||||
syncScheduler.Stop()
|
syncScheduler.Stop()
|
||||||
timestampScheduler.Stop()
|
keepAlive.Stop()
|
||||||
ctrlServer.Close()
|
ctrlServer.Close()
|
||||||
client.Close()
|
client.Close()
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,7 @@ type CrdtMeshManager struct {
|
|||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
doc *automerge.Doc
|
doc *automerge.Doc
|
||||||
LastHash automerge.ChangeHash
|
LastHash automerge.ChangeHash
|
||||||
conf *conf.WgMeshConfiguration
|
conf *conf.WgConfiguration
|
||||||
cache *MeshCrdt
|
cache *MeshCrdt
|
||||||
lastCacheHash automerge.ChangeHash
|
lastCacheHash automerge.ChangeHash
|
||||||
}
|
}
|
||||||
@ -74,8 +74,8 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
keepAliveTime := timestamp.Int64()
|
return true
|
||||||
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
|
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CrdtMeshManager) GetPeers() []string {
|
func (c *CrdtMeshManager) GetPeers() []string {
|
||||||
@ -135,7 +135,7 @@ type NewCrdtNodeMangerParams struct {
|
|||||||
MeshId string
|
MeshId string
|
||||||
DevName string
|
DevName string
|
||||||
Port int
|
Port int
|
||||||
Conf conf.WgMeshConfiguration
|
Conf *conf.WgConfiguration
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,7 +146,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
|
|||||||
manager.doc = automerge.New()
|
manager.doc = automerge.New()
|
||||||
manager.IfName = params.DevName
|
manager.IfName = params.DevName
|
||||||
manager.Client = params.Client
|
manager.Client = params.Client
|
||||||
manager.conf = ¶ms.Conf
|
manager.conf = params.Conf
|
||||||
manager.cache = nil
|
manager.cache = nil
|
||||||
return &manager, nil
|
return &manager, nil
|
||||||
}
|
}
|
||||||
@ -449,7 +449,7 @@ func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes deletes the specified routes
|
// DeleteRoutes deletes the specified routes
|
||||||
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -467,65 +467,26 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
err = routeMap.Map().Delete(route)
|
err = routeMap.Map().Delete(route.GetDestination().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfiguration: gets the configuration for this mesh network
|
||||||
|
func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
|
||||||
|
return m.conf
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark: mark the node as locally dead
|
||||||
|
func (m *CrdtMeshManager) Mark(nodeId string) {
|
||||||
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
||||||
return NewAutomergeSync(m)
|
return NewAutomergeSync(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtMeshManager) Prune(pruneTime int) error {
|
func (m *CrdtMeshManager) Prune() error {
|
||||||
nodes, err := m.doc.Path("nodes").Get()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if nodes.Kind() != automerge.KindMap {
|
|
||||||
return errors.New("node must be a map")
|
|
||||||
}
|
|
||||||
|
|
||||||
values, err := nodes.Map().Values()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
deletionNodes := make([]string, 0)
|
|
||||||
|
|
||||||
for nodeId, node := range values {
|
|
||||||
if node.Kind() != automerge.KindMap {
|
|
||||||
return errors.New("node must be a map")
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeMap := node.Map()
|
|
||||||
|
|
||||||
timeStamp, err := nodeMap.Get("timestamp")
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if timeStamp.Kind() != automerge.KindInt64 {
|
|
||||||
return errors.New("timestamp is not int64")
|
|
||||||
}
|
|
||||||
|
|
||||||
timeValue := timeStamp.Int64()
|
|
||||||
nowValue := time.Now().Unix()
|
|
||||||
|
|
||||||
if nowValue-timeValue >= int64(pruneTime) {
|
|
||||||
deletionNodes = append(deletionNodes, nodeId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, node := range deletionNodes {
|
|
||||||
logging.Log.WriteInfof("Pruning %s", node)
|
|
||||||
nodes.Map().Delete(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ func setUpTests() *TestParams {
|
|||||||
DevName: "wg0",
|
DevName: "wg0",
|
||||||
Port: 5000,
|
Port: 5000,
|
||||||
Client: nil,
|
Client: nil,
|
||||||
Conf: conf.WgMeshConfiguration{},
|
Conf: conf.DaemonConfiguration{},
|
||||||
})
|
})
|
||||||
|
|
||||||
return &TestParams{
|
return &TestParams{
|
||||||
|
@ -14,13 +14,13 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
|
|||||||
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
|
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
|
||||||
MeshId: params.MeshId,
|
MeshId: params.MeshId,
|
||||||
DevName: params.DevName,
|
DevName: params.DevName,
|
||||||
Conf: *params.Conf,
|
Conf: params.Conf,
|
||||||
Client: params.Client,
|
Client: params.Client,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type MeshNodeFactory struct {
|
type MeshNodeFactory struct {
|
||||||
Config conf.WgMeshConfiguration
|
Config conf.DaemonConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build builds the mesh node that represents the host machine to add
|
// Build builds the mesh node that represents the host machine to add
|
||||||
@ -30,7 +30,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
|
|||||||
|
|
||||||
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
|
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
|
||||||
|
|
||||||
if f.Config.Role == conf.CLIENT_ROLE {
|
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
|
||||||
grpcEndpoint = "-"
|
grpcEndpoint = "-"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
|
|||||||
Routes: make(map[string]Route),
|
Routes: make(map[string]Route),
|
||||||
Description: "",
|
Description: "",
|
||||||
Alias: "",
|
Alias: "",
|
||||||
Type: string(f.Config.Role),
|
Type: string(*params.MeshConfig.Role),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,12 +54,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
|
|||||||
|
|
||||||
if params.Endpoint != "" {
|
if params.Endpoint != "" {
|
||||||
hostName = params.Endpoint
|
hostName = params.Endpoint
|
||||||
} else if len(f.Config.Endpoint) != 0 {
|
} else if len(*params.MeshConfig.Endpoint) != 0 {
|
||||||
hostName = f.Config.Endpoint
|
hostName = *params.MeshConfig.Endpoint
|
||||||
} else {
|
} else {
|
||||||
ipFunc := lib.GetPublicIP
|
ipFunc := lib.GetPublicIP
|
||||||
|
|
||||||
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
|
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
|
||||||
ipFunc = lib.GetOutboundIP
|
ipFunc = lib.GetOutboundIP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
33
pkg/cmd/cmd.go
Normal file
33
pkg/cmd/cmd.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// cmd is a package for running commands in the different operating systems implementations
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CmdRunner interface {
|
||||||
|
RunCommands(commands ...string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnixCmdRunner struct{}
|
||||||
|
|
||||||
|
// RunCommand: runs the unix command. It splits the command into fields
|
||||||
|
// and then runs the command accordingly
|
||||||
|
func RunCommand(cmd string) error {
|
||||||
|
args := strings.Fields(cmd)
|
||||||
|
c := exec.Command(args[0], args[1:]...)
|
||||||
|
return c.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *UnixCmdRunner) RunCommands(commands ...string) error {
|
||||||
|
for _, cmd := range commands {
|
||||||
|
err := RunCommand(cmd)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
279
pkg/conf/conf.go
279
pkg/conf/conf.go
@ -4,7 +4,7 @@ package conf
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
"github.com/go-playground/validator/v10"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,170 +30,187 @@ const (
|
|||||||
DNS_IP_DISCOVERY = "dns"
|
DNS_IP_DISCOVERY = "dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgMeshConfiguration struct {
|
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
|
||||||
|
// tell if the attribute is set
|
||||||
|
type WgConfiguration struct {
|
||||||
|
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
|
||||||
|
// service for IPDiscoverability
|
||||||
|
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"`
|
||||||
|
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
|
||||||
|
AdvertiseRoutes *bool `yaml:"advertiseRoutes"`
|
||||||
|
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
|
||||||
|
// for all nodes to route their packets to
|
||||||
|
AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults"`
|
||||||
|
// Endpoint contains what value should be set as the public endpoint of this node
|
||||||
|
Endpoint *string `yaml:"publicEndpoint"`
|
||||||
|
// Role specifies whether or not the user is globally accessible.
|
||||||
|
// If the user is globaly accessible they specify themselves as a client.
|
||||||
|
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
|
||||||
|
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
|
||||||
|
// KeepAlive can only be set if role is type client
|
||||||
|
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
|
||||||
|
// PreUp are WireGuard commands to run before adding the WG interface
|
||||||
|
PreUp []string `yaml:"preUp"`
|
||||||
|
// PostUp are WireGuard commands to run after adding the WG interface
|
||||||
|
PostUp []string `yaml:"postUp"`
|
||||||
|
// PreDown are WireGuard commands to run prior to removing the WG interface
|
||||||
|
PreDown []string `yaml:"preDown"`
|
||||||
|
// PostDown are WireGuard command to run after removing the WG interface
|
||||||
|
PostDown []string `yaml:"postDown"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DaemonConfiguration struct {
|
||||||
// CertificatePath is the path to the certificate to use in mTLS
|
// CertificatePath is the path to the certificate to use in mTLS
|
||||||
CertificatePath string `yaml:"certificatePath"`
|
CertificatePath string `yaml:"certificatePath" validate:"required,file"`
|
||||||
// PrivateKeypath is the path to the clients private key in mTLS
|
// PrivateKeypath is the path to the clients private key in mTLS
|
||||||
PrivateKeyPath string `yaml:"privateKeyPath"`
|
PrivateKeyPath string `yaml:"privateKeyPath" validate:"required,file"`
|
||||||
// CaCeritifcatePath path to the certificate of the trust certificate authority
|
// CaCeritifcatePath path to the certificate of the trust certificate authority
|
||||||
CaCertificatePath string `yaml:"caCertificatePath"`
|
CaCertificatePath string `yaml:"caCertificatePath" validate:"required,file"`
|
||||||
// SkipCertVerification specify to skip certificate verification. Should only be used
|
// SkipCertVerification specify to skip certificate verification. Should only be used
|
||||||
// in test environments
|
// in test environments
|
||||||
SkipCertVerification bool `yaml:"skipCertVerification"`
|
SkipCertVerification bool `yaml:"skipCertVerification"`
|
||||||
// Port to run the GrpcServer on
|
// Port to run the GrpcServer on
|
||||||
GrpcPort string `yaml:"gRPCPort"`
|
GrpcPort int `yaml:"gRPCPort" validate:"required"`
|
||||||
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
|
// Timeout number of seconds without response that a node is considered unreachable by gRPC
|
||||||
// use public IP discovery library
|
Timeout int `yaml:"timeout" validate:"required,gte=1"`
|
||||||
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
|
|
||||||
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
|
|
||||||
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
|
|
||||||
// Endpoint is the IP in which this computer is publicly reachable.
|
|
||||||
// usecase is when the node has multiple IP addresses
|
|
||||||
Endpoint string `yaml:"publicEndpoint"`
|
|
||||||
// ClusterSize size of the cluster to split on
|
|
||||||
ClusterSize int `yaml:"clusterSize"`
|
|
||||||
// SyncRate number of times per second to perform a sync
|
|
||||||
SyncRate float64 `yaml:"syncRate"`
|
|
||||||
// InterClusterChance proability of inter-cluster communication in a sync round
|
|
||||||
InterClusterChance float64 `yaml:"interClusterChance"`
|
|
||||||
// BranchRate number of nodes to randomly communicate with
|
|
||||||
BranchRate int `yaml:"branchRate"`
|
|
||||||
// InfectionCount number of times we sync before we can no longer catch the udpate
|
|
||||||
InfectionCount int `yaml:"infectionCount"`
|
|
||||||
// KeepAliveTime number of seconds before we update node indicating that we are still alive
|
|
||||||
KeepAliveTime int `yaml:"keepAliveTime"`
|
|
||||||
// Timeout number of seconds before we consider the node as dead
|
|
||||||
Timeout int `yaml:"timeout"`
|
|
||||||
// PruneTime number of seconds before we remove nodes that are likely to be dead
|
|
||||||
PruneTime int `yaml:"pruneTime"`
|
|
||||||
// DeadTime: number of seconds before we consider the node as dead and stop considering it
|
|
||||||
// when picking a random peer
|
|
||||||
DeadTime int `yaml:"deadTime"`
|
|
||||||
// Profile whether or not to include a http server that profiles the code
|
// Profile whether or not to include a http server that profiles the code
|
||||||
Profile bool `yaml:"profile"`
|
Profile bool `yaml:"profile"`
|
||||||
// StubWg whether or not to stub the WireGuard types
|
// StubWg whether or not to stub the WireGuard types
|
||||||
StubWg bool `yaml:"stubWg"`
|
StubWg bool `yaml:"stubWg"`
|
||||||
// Role specifies whether or not the user is globally accessible.
|
// SyncRate specifies how long the minimum time should be between synchronisation
|
||||||
// If the user is globaly accessible they specify themselves as a client.
|
SyncRate int `yaml:"syncRate" validate:"required,gte=1"`
|
||||||
Role NodeType `yaml:"role"`
|
// KeepAliveTime: number of seconds before the leader of the mesh sends an update to
|
||||||
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
|
// send to every member in the mesh
|
||||||
// KeepAlive can only be set if role is type client
|
KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"`
|
||||||
KeepAliveWg int `yaml:"keepAliveWg"`
|
// ClusterSize specifies how many neighbours you should synchronise with per round
|
||||||
|
ClusterSize int `yaml:"clusterSize" valdiate:"required,gt=0"`
|
||||||
|
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
|
||||||
|
InterClusterChance float64 `yaml:"interClusterChance" valdiate:"required,gt=0"`
|
||||||
|
// BranchRate specifies the number of nodes to synchronise with when a node has
|
||||||
|
// new changes to send to the mesh
|
||||||
|
BranchRate int `yaml:"branchRate" validate:"required,gte=1"`
|
||||||
|
// InfectionCount: number of time to sync before an update can no longer be 'caught'
|
||||||
|
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
|
||||||
|
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided
|
||||||
|
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateConfiguration(c *WgMeshConfiguration) error {
|
// ValdiateMeshConfiguration: validates the mesh configuration
|
||||||
if len(c.CertificatePath) == 0 {
|
func ValidateMeshConfiguration(conf *WgConfiguration) error {
|
||||||
return &WgMeshConfigurationError{
|
validate := validator.New(validator.WithRequiredStructEnabled())
|
||||||
msg: "A public certificate must be specified for mTLS",
|
err := validate.Struct(conf)
|
||||||
}
|
|
||||||
|
if conf.PostDown == nil {
|
||||||
|
conf.PostDown = make([]string, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.PrivateKeyPath) == 0 {
|
if conf.PostUp == nil {
|
||||||
return &WgMeshConfigurationError{
|
conf.PostUp = make([]string, 0)
|
||||||
msg: "A private key must be specified for mTLS",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.CaCertificatePath) == 0 {
|
if conf.PreDown == nil {
|
||||||
return &WgMeshConfigurationError{
|
conf.PreDown = make([]string, 0)
|
||||||
msg: "A ca certificate must be specified for mTLS",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.GrpcPort) == 0 {
|
if conf.PreUp == nil {
|
||||||
return &WgMeshConfigurationError{
|
conf.PreUp = make([]string, 0)
|
||||||
msg: "A grpc port must be specified",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.ClusterSize <= 0 {
|
return err
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "A cluster size must not be 0",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.SyncRate <= 0 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "SyncRate cannot be negative",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.BranchRate <= 0 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Branch rate cannot be negative",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.InfectionCount <= 0 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Infection count cannot be less than 1",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.KeepAliveTime <= 0 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "KeepAliveRate cannot be less than negative",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.InterClusterChance <= 0 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Intercluster chance cannot be less than 0",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Timeout < 1 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Timeout should be greater than or equal to 1",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.PruneTime < 1 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Prune time cannot be < 1",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.DeadTime < 1 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Dead time cannot be < 1",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.KeepAliveTime <= 1 {
|
|
||||||
return &WgMeshConfigurationError{
|
|
||||||
msg: "Prune time cannot be less than keep alive time",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Role == "" {
|
|
||||||
c.Role = PEER_ROLE
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.IPDiscovery == "" {
|
|
||||||
c.IPDiscovery = PUBLIC_IP_DISCOVERY
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseConfiguration parses the mesh configuration
|
// ValidateDaemonConfiguration: validates the dameon configuration that is used.
|
||||||
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
|
func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
|
||||||
var conf WgMeshConfiguration
|
validate := validator.New(validator.WithRequiredStructEnabled())
|
||||||
|
err := validate.Struct(c)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
|
||||||
|
// keepalive time, role and so forth.
|
||||||
|
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
|
||||||
|
var conf WgConfiguration
|
||||||
|
|
||||||
yamlBytes, err := os.ReadFile(filePath)
|
yamlBytes, err := os.ReadFile(filePath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = yaml.Unmarshal(yamlBytes, &conf)
|
err = yaml.Unmarshal(yamlBytes, &conf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &conf, ValidateConfiguration(&conf)
|
return &conf, ValidateMeshConfiguration(&conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
|
||||||
|
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
|
||||||
|
var conf DaemonConfiguration
|
||||||
|
|
||||||
|
yamlBytes, err := os.ReadFile(filePath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = yaml.Unmarshal(yamlBytes, &conf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &conf, ValidateDaemonConfiguration(&conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergemeshConfiguration: merges the configuration in precedence where the last
|
||||||
|
// element in the list takes the most and the first takes the least
|
||||||
|
func MergeMeshConfiguration(cfgs ...WgConfiguration) (WgConfiguration, error) {
|
||||||
|
var result WgConfiguration
|
||||||
|
|
||||||
|
for _, cfg := range cfgs {
|
||||||
|
if cfg.AdvertiseDefaultRoute != nil {
|
||||||
|
result.AdvertiseDefaultRoute = cfg.AdvertiseDefaultRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.AdvertiseRoutes != nil {
|
||||||
|
result.AdvertiseRoutes = cfg.AdvertiseRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Endpoint != nil {
|
||||||
|
result.Endpoint = cfg.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.IPDiscovery != nil {
|
||||||
|
result.IPDiscovery = cfg.IPDiscovery
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.KeepAliveWg != nil {
|
||||||
|
result.KeepAliveWg = cfg.KeepAliveWg
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.PostDown != nil {
|
||||||
|
result.PostDown = cfg.PostDown
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.PostUp != nil {
|
||||||
|
result.PostUp = cfg.PostUp
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.PreDown != nil {
|
||||||
|
result.PreDown = cfg.PreDown
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.PreUp != nil {
|
||||||
|
result.PreUp = cfg.PreUp
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Role != nil {
|
||||||
|
result.Role = cfg.Role
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, ValidateMeshConfiguration(&result)
|
||||||
}
|
}
|
||||||
|
@ -2,23 +2,12 @@ package conf
|
|||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func getExampleConfiguration() *WgMeshConfiguration {
|
func getExampleConfiguration() *DaemonConfiguration {
|
||||||
return &WgMeshConfiguration{
|
return &DaemonConfiguration{
|
||||||
CertificatePath: "./cert/cert.pem",
|
CertificatePath: "./cert/cert.pem",
|
||||||
PrivateKeyPath: "./cert/key.pem",
|
PrivateKeyPath: "./cert/key.pem",
|
||||||
CaCertificatePath: "./cert/ca.pems",
|
CaCertificatePath: "./cert/ca.pems",
|
||||||
SkipCertVerification: true,
|
SkipCertVerification: true,
|
||||||
GrpcPort: "8080",
|
|
||||||
AdvertiseRoutes: true,
|
|
||||||
Endpoint: "localhost",
|
|
||||||
ClusterSize: 1,
|
|
||||||
SyncRate: 1,
|
|
||||||
InterClusterChance: 0.1,
|
|
||||||
BranchRate: 2,
|
|
||||||
KeepAliveTime: 4,
|
|
||||||
InfectionCount: 1,
|
|
||||||
Timeout: 2,
|
|
||||||
PruneTime: 20,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,7 +15,7 @@ func TestConfigurationCertificatePathEmpty(t *testing.T) {
|
|||||||
conf := getExampleConfiguration()
|
conf := getExampleConfiguration()
|
||||||
conf.CertificatePath = ""
|
conf.CertificatePath = ""
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
err := ValidateDaemonConfiguration(conf)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(`error should be thrown`)
|
t.Fatal(`error should be thrown`)
|
||||||
@ -37,7 +26,7 @@ func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
|
|||||||
conf := getExampleConfiguration()
|
conf := getExampleConfiguration()
|
||||||
conf.PrivateKeyPath = ""
|
conf.PrivateKeyPath = ""
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
err := ValidateDaemonConfiguration(conf)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(`error should be thrown`)
|
t.Fatal(`error should be thrown`)
|
||||||
@ -48,7 +37,7 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
|
|||||||
conf := getExampleConfiguration()
|
conf := getExampleConfiguration()
|
||||||
conf.CaCertificatePath = ""
|
conf.CaCertificatePath = ""
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
err := ValidateDaemonConfiguration(conf)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(`error should be thrown`)
|
t.Fatal(`error should be thrown`)
|
||||||
@ -57,109 +46,21 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
|
|||||||
|
|
||||||
func TestConfigurationGrpcPortEmpty(t *testing.T) {
|
func TestConfigurationGrpcPortEmpty(t *testing.T) {
|
||||||
conf := getExampleConfiguration()
|
conf := getExampleConfiguration()
|
||||||
conf.GrpcPort = ""
|
conf.GrpcPort = 0
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
err := ValidateDaemonConfiguration(conf)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(`error should be thrown`)
|
t.Fatal(`error should be thrown`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClusterSizeZero(t *testing.T) {
|
func TestValidConfiguration(t *testing.T) {
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.ClusterSize = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal(`error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SyncRateZero(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.SyncRate = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal(`error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BranchRateZero(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.BranchRate = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal(`error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func InfectionCountZero(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.InfectionCount = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal(`error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func KeepAliveRateZero(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.KeepAliveTime = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal(`error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidCOnfiguration(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
conf := getExampleConfiguration()
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
err := ValidateDaemonConfiguration(conf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeout(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.Timeout = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal(`error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPruneTimeZero(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.PruneTime = 0
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf(`Error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
|
|
||||||
conf := getExampleConfiguration()
|
|
||||||
conf.PruneTime = 1
|
|
||||||
|
|
||||||
err := ValidateConfiguration(conf)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf(`Error should be thrown`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -2,6 +2,7 @@ package conn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
@ -21,13 +22,13 @@ type ConnectionServer struct {
|
|||||||
ctrlProvider rpc.MeshCtrlServerServer
|
ctrlProvider rpc.MeshCtrlServerServer
|
||||||
// the sync service to synchronise nodes
|
// the sync service to synchronise nodes
|
||||||
syncProvider rpc.SyncServiceServer
|
syncProvider rpc.SyncServiceServer
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.DaemonConfiguration
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConnectionServerParams contains params for creating a new connection server
|
// NewConnectionServerParams contains params for creating a new connection server
|
||||||
type NewConnectionServerParams struct {
|
type NewConnectionServerParams struct {
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.DaemonConfiguration
|
||||||
CtrlProvider rpc.MeshCtrlServerServer
|
CtrlProvider rpc.MeshCtrlServerServer
|
||||||
SyncProvider rpc.SyncServiceServer
|
SyncProvider rpc.SyncServiceServer
|
||||||
}
|
}
|
||||||
@ -76,10 +77,10 @@ func (s *ConnectionServer) Listen() error {
|
|||||||
|
|
||||||
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
|
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort)
|
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))
|
||||||
s.listener = lis
|
s.listener = lis
|
||||||
|
|
||||||
logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort)
|
logging.Log.WriteInfof("GRPC listening on %d\n", s.Conf.GrpcPort)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteErrorf(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -48,6 +49,13 @@ type MeshNode struct {
|
|||||||
Description string
|
Description string
|
||||||
Services map[string]string
|
Services map[string]string
|
||||||
Type string
|
Type string
|
||||||
|
Tombstone bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark: marks the node is unreachable. This is not broadcast on
|
||||||
|
// syncrhonisation
|
||||||
|
func (m *TwoPhaseStoreMeshManager) Mark(nodeId string) {
|
||||||
|
m.store.Mark(nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHostEndpoint: gets the gRPC endpoint of the node
|
// GetHostEndpoint: gets the gRPC endpoint of the node
|
||||||
@ -146,12 +154,13 @@ func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseStoreMeshManager struct {
|
type TwoPhaseStoreMeshManager struct {
|
||||||
MeshId string
|
MeshId string
|
||||||
IfName string
|
IfName string
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
LastClock uint64
|
LastClock uint64
|
||||||
conf *conf.WgMeshConfiguration
|
conf *conf.WgConfiguration
|
||||||
store *TwoPhaseMap[string, MeshNode]
|
daemonConf *conf.DaemonConfiguration
|
||||||
|
store *TwoPhaseMap[string, MeshNode]
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddNode() adds a node to the mesh
|
// AddNode() adds a node to the mesh
|
||||||
@ -171,8 +180,16 @@ func (m *TwoPhaseStoreMeshManager) AddNode(node mesh.MeshNode) {
|
|||||||
|
|
||||||
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
|
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
|
||||||
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
func (m *TwoPhaseStoreMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
||||||
|
nodes := m.store.AsList()
|
||||||
|
|
||||||
|
snapshot := make(map[string]MeshNode)
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
snapshot[node.PublicKey] = node
|
||||||
|
}
|
||||||
|
|
||||||
return &MeshSnapshot{
|
return &MeshSnapshot{
|
||||||
Nodes: m.store.AsMap(),
|
Nodes: snapshot,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,11 +217,11 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
|
|||||||
// Load() loads a mesh network
|
// Load() loads a mesh network
|
||||||
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
|
func (m *TwoPhaseStoreMeshManager) Load(bs []byte) error {
|
||||||
buf := bytes.NewBuffer(bs)
|
buf := bytes.NewBuffer(bs)
|
||||||
|
|
||||||
dec := gob.NewDecoder(buf)
|
dec := gob.NewDecoder(buf)
|
||||||
|
|
||||||
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
|
var snapshot TwoPhaseMapSnapshot[string, MeshNode]
|
||||||
err := dec.Decode(&snapshot)
|
err := dec.Decode(&snapshot)
|
||||||
|
|
||||||
m.store.Merge(snapshot)
|
m.store.Merge(snapshot)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -238,6 +255,31 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
|
|||||||
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sort nodes by their public key
|
||||||
|
peers := m.GetPeers()
|
||||||
|
slices.Sort(peers)
|
||||||
|
|
||||||
|
if len(peers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerToUpdate := peers[0]
|
||||||
|
|
||||||
|
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) {
|
||||||
|
m.store.Mark(peerToUpdate)
|
||||||
|
|
||||||
|
if len(peers) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerToUpdate = peers[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerToUpdate != nodeId {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh causing node to update it's time stamp
|
||||||
node := m.store.Get(nodeId)
|
node := m.store.Get(nodeId)
|
||||||
node.Timestamp = time.Now().Unix()
|
node.Timestamp = time.Now().Unix()
|
||||||
m.store.Put(nodeId, node)
|
m.store.Put(nodeId, node)
|
||||||
@ -256,19 +298,30 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
|
|||||||
|
|
||||||
node := m.store.Get(nodeId)
|
node := m.store.Get(nodeId)
|
||||||
|
|
||||||
|
changes := false
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
node.Routes[route.GetDestination().String()] = Route{
|
prevRoute, ok := node.Routes[route.GetDestination().String()]
|
||||||
Destination: route.GetDestination().String(),
|
|
||||||
Path: route.GetPath(),
|
if !ok || route.GetHopCount() < prevRoute.GetHopCount() {
|
||||||
|
changes = true
|
||||||
|
|
||||||
|
node.Routes[route.GetDestination().String()] = Route{
|
||||||
|
Destination: route.GetDestination().String(),
|
||||||
|
Path: route.GetPath(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.store.Put(nodeId, node)
|
if changes {
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoutes: deletes the routes from the node
|
// DeleteRoutes: deletes the routes from the node
|
||||||
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
|
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
|
||||||
if !m.store.Contains(nodeId) {
|
if !m.store.Contains(nodeId) {
|
||||||
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
|
||||||
}
|
}
|
||||||
@ -279,8 +332,15 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...string)
|
|||||||
|
|
||||||
node := m.store.Get(nodeId)
|
node := m.store.Get(nodeId)
|
||||||
|
|
||||||
|
changes := false
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
delete(node.Routes, route)
|
changes = true
|
||||||
|
delete(node.Routes, route.GetDestination().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if changes {
|
||||||
|
m.store.Put(nodeId, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -357,20 +417,27 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prune: prunes all nodes that have not updated their timestamp in
|
// Prune: prunes all nodes that have not updated their timestamp in
|
||||||
// pruneAmount seconds
|
func (m *TwoPhaseStoreMeshManager) Prune() error {
|
||||||
func (m *TwoPhaseStoreMeshManager) Prune(pruneAmount int) error {
|
m.store.Prune()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeers: get a list of contactable peers
|
// GetPeers: get a list of contactable peers
|
||||||
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
|
||||||
nodes := lib.MapValues(m.store.AsMap())
|
nodes := m.store.AsList()
|
||||||
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
nodes = lib.Filter(nodes, func(mn MeshNode) bool {
|
||||||
if mn.Type != string(conf.PEER_ROLE) {
|
if mn.Type != string(conf.PEER_ROLE) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Now().Unix()-mn.Timestamp < int64(m.conf.DeadTime)
|
// If the node is marked as unreachable don't consider it a peer.
|
||||||
|
// this help to optimize convergence time for unreachable nodes.
|
||||||
|
// However advertising it to other nodes could result in flapping.
|
||||||
|
if m.store.IsMarked(mn.PublicKey) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
return lib.Map(nodes, func(mn MeshNode) string {
|
return lib.Map(nodes, func(mn MeshNode) string {
|
||||||
@ -440,3 +507,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
|
|||||||
m.store.Remove(nodeId)
|
m.store.Remove(nodeId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfiguration implements mesh.MeshProvider.
|
||||||
|
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
|
||||||
|
return m.conf
|
||||||
|
}
|
||||||
|
@ -2,46 +2,56 @@ package crdt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TwoPhaseMapFactory struct{}
|
type TwoPhaseMapFactory struct {
|
||||||
|
Config *conf.DaemonConfiguration
|
||||||
|
}
|
||||||
|
|
||||||
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
|
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
|
||||||
return &TwoPhaseStoreMeshManager{
|
return &TwoPhaseStoreMeshManager{
|
||||||
MeshId: params.MeshId,
|
MeshId: params.MeshId,
|
||||||
IfName: params.DevName,
|
IfName: params.DevName,
|
||||||
Client: params.Client,
|
Client: params.Client,
|
||||||
conf: params.Conf,
|
conf: params.Conf,
|
||||||
store: NewTwoPhaseMap[string, MeshNode](params.NodeID),
|
daemonConf: params.DaemonConf,
|
||||||
|
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
|
||||||
|
h := fnv.New64a()
|
||||||
|
h.Write([]byte(s))
|
||||||
|
return h.Sum64()
|
||||||
|
}, uint64(3*f.Config.KeepAliveTime)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MeshNodeFactory struct {
|
type MeshNodeFactory struct {
|
||||||
Config conf.WgMeshConfiguration
|
Config conf.DaemonConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
|
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
|
||||||
hostName := f.getAddress(params)
|
hostName := f.getAddress(params)
|
||||||
|
|
||||||
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
|
grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
|
||||||
|
wgEndpoint := fmt.Sprintf("%s:%d", hostName, params.WgPort)
|
||||||
|
|
||||||
if f.Config.Role == conf.CLIENT_ROLE {
|
if *params.MeshConfig.Role == conf.CLIENT_ROLE {
|
||||||
grpcEndpoint = "-"
|
grpcEndpoint = "-"
|
||||||
|
wgEndpoint = "-"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &MeshNode{
|
return &MeshNode{
|
||||||
HostEndpoint: grpcEndpoint,
|
HostEndpoint: grpcEndpoint,
|
||||||
PublicKey: params.PublicKey.String(),
|
PublicKey: params.PublicKey.String(),
|
||||||
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
|
WgEndpoint: wgEndpoint,
|
||||||
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
|
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
|
||||||
Routes: make(map[string]Route),
|
Routes: make(map[string]Route),
|
||||||
Description: "",
|
Description: "",
|
||||||
Alias: "",
|
Alias: "",
|
||||||
Type: string(f.Config.Role),
|
Type: string(*params.MeshConfig.Role),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,12 +61,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
|
|||||||
|
|
||||||
if params.Endpoint != "" {
|
if params.Endpoint != "" {
|
||||||
hostName = params.Endpoint
|
hostName = params.Endpoint
|
||||||
} else if len(f.Config.Endpoint) != 0 {
|
} else if params.MeshConfig.Endpoint != nil && len(*params.MeshConfig.Endpoint) != 0 {
|
||||||
hostName = f.Config.Endpoint
|
hostName = *params.MeshConfig.Endpoint
|
||||||
} else {
|
} else {
|
||||||
ipFunc := lib.GetPublicIP
|
ipFunc := lib.GetPublicIP
|
||||||
|
|
||||||
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
|
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
|
||||||
ipFunc = lib.GetOutboundIP
|
ipFunc = lib.GetOutboundIP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,27 +2,29 @@
|
|||||||
package crdt
|
package crdt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bucket[D any] struct {
|
type Bucket[D any] struct {
|
||||||
Vector uint64
|
Vector uint64
|
||||||
Contents D
|
Contents D
|
||||||
|
Gravestone bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GMap is a set that can only grow in size
|
// GMap is a set that can only grow in size
|
||||||
type GMap[K comparable, D any] struct {
|
type GMap[K cmp.Ordered, D any] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
contents map[K]Bucket[D]
|
contents map[uint64]Bucket[D]
|
||||||
getClock func() uint64
|
clock *VectorClock[K]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Put(key K, value D) {
|
func (g *GMap[K, D]) Put(key K, value D) {
|
||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
|
|
||||||
clock := g.getClock() + 1
|
clock := g.clock.IncrementClock()
|
||||||
|
|
||||||
g.contents[key] = Bucket[D]{
|
g.contents[g.clock.hashFunc(key)] = Bucket[D]{
|
||||||
Vector: clock,
|
Vector: clock,
|
||||||
Contents: value,
|
Contents: value,
|
||||||
}
|
}
|
||||||
@ -31,6 +33,10 @@ func (g *GMap[K, D]) Put(key K, value D) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Contains(key K) bool {
|
func (g *GMap[K, D]) Contains(key K) bool {
|
||||||
|
return g.contains(g.clock.hashFunc(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) contains(key uint64) bool {
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
_, ok := g.contents[key]
|
_, ok := g.contents[key]
|
||||||
@ -40,7 +46,7 @@ func (g *GMap[K, D]) Contains(key K) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
func (g *GMap[K, D]) put(key uint64, b Bucket[D]) {
|
||||||
g.lock.Lock()
|
g.lock.Lock()
|
||||||
|
|
||||||
if g.contents[key].Vector < b.Vector {
|
if g.contents[key].Vector < b.Vector {
|
||||||
@ -50,7 +56,7 @@ func (g *GMap[K, D]) put(key K, b Bucket[D]) {
|
|||||||
g.lock.Unlock()
|
g.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) get(key K) Bucket[D] {
|
func (g *GMap[K, D]) get(key uint64) Bucket[D] {
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
bucket := g.contents[key]
|
bucket := g.contents[key]
|
||||||
g.lock.RUnlock()
|
g.lock.RUnlock()
|
||||||
@ -59,13 +65,38 @@ func (g *GMap[K, D]) get(key K) Bucket[D] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Get(key K) D {
|
func (g *GMap[K, D]) Get(key K) D {
|
||||||
return g.get(key).Contents
|
return g.get(g.clock.hashFunc(key)).Contents
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Keys() []K {
|
func (g *GMap[K, D]) Mark(key K) {
|
||||||
|
g.lock.Lock()
|
||||||
|
bucket := g.contents[g.clock.hashFunc(key)]
|
||||||
|
bucket.Gravestone = true
|
||||||
|
g.contents[g.clock.hashFunc(key)] = bucket
|
||||||
|
g.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMarked: returns true if the node is marked
|
||||||
|
func (g *GMap[K, D]) IsMarked(key K) bool {
|
||||||
|
marked := false
|
||||||
|
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
contents := make([]K, len(g.contents))
|
bucket, ok := g.contents[g.clock.hashFunc(key)]
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
marked = bucket.Gravestone
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
|
||||||
|
return marked
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Keys() []uint64 {
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
contents := make([]uint64, len(g.contents))
|
||||||
index := 0
|
index := 0
|
||||||
|
|
||||||
for key := range g.contents {
|
for key := range g.contents {
|
||||||
@ -77,8 +108,8 @@ func (g *GMap[K, D]) Keys() []K {
|
|||||||
return contents
|
return contents
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
|
||||||
buckets := make(map[K]Bucket[D])
|
buckets := make(map[uint64]Bucket[D])
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
for key, value := range g.contents {
|
for key, value := range g.contents {
|
||||||
@ -89,8 +120,8 @@ func (g *GMap[K, D]) Save() map[K]Bucket[D] {
|
|||||||
return buckets
|
return buckets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
|
||||||
buckets := make(map[K]Bucket[D])
|
buckets := make(map[uint64]Bucket[D])
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
@ -101,8 +132,8 @@ func (g *GMap[K, D]) SaveWithKeys(keys []K) map[K]Bucket[D] {
|
|||||||
return buckets
|
return buckets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GMap[K, D]) GetClock() map[K]uint64 {
|
func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
|
||||||
clock := make(map[K]uint64)
|
clock := make(map[uint64]uint64)
|
||||||
g.lock.RLock()
|
g.lock.RLock()
|
||||||
|
|
||||||
for key, bucket := range g.contents {
|
for key, bucket := range g.contents {
|
||||||
@ -113,9 +144,33 @@ func (g *GMap[K, D]) GetClock() map[K]uint64 {
|
|||||||
return clock
|
return clock
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGMap[K comparable, D any](getClock func() uint64) *GMap[K, D] {
|
func (g *GMap[K, D]) GetHash() uint64 {
|
||||||
|
hash := uint64(0)
|
||||||
|
|
||||||
|
g.lock.RLock()
|
||||||
|
|
||||||
|
for _, value := range g.contents {
|
||||||
|
hash += value.Vector
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.RUnlock()
|
||||||
|
return hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GMap[K, D]) Prune() {
|
||||||
|
stale := g.clock.getStale()
|
||||||
|
g.lock.Lock()
|
||||||
|
|
||||||
|
for _, outlier := range stale {
|
||||||
|
delete(g.contents, outlier)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGMap[K cmp.Ordered, D any](clock *VectorClock[K]) *GMap[K, D] {
|
||||||
return &GMap[K, D]{
|
return &GMap[K, D]{
|
||||||
contents: make(map[K]Bucket[D]),
|
contents: make(map[uint64]Bucket[D]),
|
||||||
getClock: getClock,
|
clock: clock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,33 +1,37 @@
|
|||||||
package crdt
|
package crdt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"cmp"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TwoPhaseMap[K comparable, D any] struct {
|
type TwoPhaseMap[K cmp.Ordered, D any] struct {
|
||||||
addMap *GMap[K, D]
|
addMap *GMap[K, D]
|
||||||
removeMap *GMap[K, bool]
|
removeMap *GMap[K, bool]
|
||||||
vectors map[K]uint64
|
Clock *VectorClock[K]
|
||||||
processId K
|
processId K
|
||||||
lock sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseMapSnapshot[K comparable, D any] struct {
|
type TwoPhaseMapSnapshot[K cmp.Ordered, D any] struct {
|
||||||
Add map[K]Bucket[D]
|
Add map[uint64]Bucket[D]
|
||||||
Remove map[K]Bucket[bool]
|
Remove map[uint64]Bucket[bool]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains checks whether the value exists in the map
|
// Contains checks whether the value exists in the map
|
||||||
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
|
func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
|
||||||
if !m.addMap.Contains(key) {
|
return m.contains(m.Clock.hashFunc(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks whether the value exists in the map
|
||||||
|
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
|
||||||
|
if !m.addMap.contains(key) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
addValue := m.addMap.get(key)
|
addValue := m.addMap.get(key)
|
||||||
|
|
||||||
if !m.removeMap.Contains(key) {
|
if !m.removeMap.contains(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,32 +50,39 @@ func (m *TwoPhaseMap[K, D]) Get(key K) D {
|
|||||||
return m.addMap.Get(key)
|
return m.addMap.Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put places the key K in the map
|
func (m *TwoPhaseMap[K, D]) get(key uint64) D {
|
||||||
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
var result D
|
||||||
msgSequence := m.incrementClock()
|
|
||||||
|
|
||||||
m.lock.Lock()
|
if !m.contains(key) {
|
||||||
|
return result
|
||||||
if _, ok := m.vectors[key]; !ok {
|
|
||||||
m.vectors[key] = msgSequence
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.lock.Unlock()
|
return m.addMap.get(key).Contents
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put places the key K in the map
|
||||||
|
func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
|
||||||
|
msgSequence := m.Clock.IncrementClock()
|
||||||
|
m.Clock.Put(key, msgSequence)
|
||||||
m.addMap.Put(key, data)
|
m.addMap.Put(key, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *TwoPhaseMap[K, D]) Mark(key K) {
|
||||||
|
m.addMap.Mark(key)
|
||||||
|
}
|
||||||
|
|
||||||
// Remove removes the value from the map
|
// Remove removes the value from the map
|
||||||
func (m *TwoPhaseMap[K, D]) Remove(key K) {
|
func (m *TwoPhaseMap[K, D]) Remove(key K) {
|
||||||
m.removeMap.Put(key, true)
|
m.removeMap.Put(key, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) Keys() []K {
|
func (m *TwoPhaseMap[K, D]) keys() []uint64 {
|
||||||
keys := make([]K, 0)
|
keys := make([]uint64, 0)
|
||||||
|
|
||||||
addKeys := m.addMap.Keys()
|
addKeys := m.addMap.Keys()
|
||||||
|
|
||||||
for _, key := range addKeys {
|
for _, key := range addKeys {
|
||||||
if !m.Contains(key) {
|
if !m.contains(key) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,16 +92,16 @@ func (m *TwoPhaseMap[K, D]) Keys() []K {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) AsMap() map[K]D {
|
func (m *TwoPhaseMap[K, D]) AsList() []D {
|
||||||
theMap := make(map[K]D)
|
theList := make([]D, 0)
|
||||||
|
|
||||||
keys := m.Keys()
|
keys := m.keys()
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
theMap[key] = m.Get(key)
|
theList = append(theList, m.get(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
return theMap
|
return theList
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
|
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
|
||||||
@ -110,37 +121,21 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type TwoPhaseMapState[K comparable] struct {
|
type TwoPhaseMapState[K cmp.Ordered] struct {
|
||||||
AddContents map[K]uint64
|
Vectors map[uint64]uint64
|
||||||
RemoveContents map[K]uint64
|
AddContents map[uint64]uint64
|
||||||
|
RemoveContents map[uint64]uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) incrementClock() uint64 {
|
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
|
||||||
maxClock := uint64(0)
|
return m.addMap.IsMarked(key)
|
||||||
m.lock.Lock()
|
|
||||||
|
|
||||||
for _, value := range m.vectors {
|
|
||||||
maxClock = max(maxClock, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.vectors[m.processId] = maxClock + 1
|
|
||||||
m.lock.Unlock()
|
|
||||||
return maxClock
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHash: Get the hash of the current state of the map
|
// GetHash: Get the hash of the current state of the map
|
||||||
// Sums the current values of the vectors. Provides good approximation
|
// Sums the current values of the vectors. Provides good approximation
|
||||||
// of increasing numbers
|
// of increasing numbers
|
||||||
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
func (m *TwoPhaseMap[K, D]) GetHash() uint64 {
|
||||||
m.lock.RLock()
|
return (m.addMap.GetHash() + 1) * (m.removeMap.GetHash() + 1)
|
||||||
|
|
||||||
sum := lib.Reduce(uint64(0), lib.MapValues(m.vectors), func(sum uint64, current uint64) uint64 {
|
|
||||||
return current + sum
|
|
||||||
})
|
|
||||||
|
|
||||||
m.lock.RUnlock()
|
|
||||||
|
|
||||||
return sum
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetState: get the current vector clock of the add and remove
|
// GetState: get the current vector clock of the add and remove
|
||||||
@ -150,6 +145,7 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
|
|||||||
removeContents := m.removeMap.GetClock()
|
removeContents := m.removeMap.GetClock()
|
||||||
|
|
||||||
return &TwoPhaseMapState[K]{
|
return &TwoPhaseMapState[K]{
|
||||||
|
Vectors: m.Clock.GetClock(),
|
||||||
AddContents: addContents,
|
AddContents: addContents,
|
||||||
RemoveContents: removeContents,
|
RemoveContents: removeContents,
|
||||||
}
|
}
|
||||||
@ -157,8 +153,8 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
|
|||||||
|
|
||||||
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
|
||||||
mapState := &TwoPhaseMapState[K]{
|
mapState := &TwoPhaseMapState[K]{
|
||||||
AddContents: make(map[K]uint64),
|
AddContents: make(map[uint64]uint64),
|
||||||
RemoveContents: make(map[K]uint64),
|
RemoveContents: make(map[uint64]uint64),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range state.AddContents {
|
for key, value := range state.AddContents {
|
||||||
@ -169,7 +165,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range state.AddContents {
|
for key, value := range state.RemoveContents {
|
||||||
otherValue, ok := m.RemoveContents[key]
|
otherValue, ok := m.RemoveContents[key]
|
||||||
|
|
||||||
if !ok || otherValue < value {
|
if !ok || otherValue < value {
|
||||||
@ -181,31 +177,35 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
|
||||||
m.lock.Lock()
|
|
||||||
|
|
||||||
for key, value := range snapshot.Add {
|
for key, value := range snapshot.Add {
|
||||||
|
// Gravestone is local only to that node.
|
||||||
|
// Discover ourselves if the node is alive
|
||||||
m.addMap.put(key, value)
|
m.addMap.put(key, value)
|
||||||
m.vectors[key] = max(value.Vector, m.vectors[key])
|
m.Clock.put(key, value.Vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range snapshot.Remove {
|
for key, value := range snapshot.Remove {
|
||||||
m.removeMap.put(key, value)
|
m.removeMap.put(key, value)
|
||||||
m.vectors[key] = max(value.Vector, m.vectors[key])
|
m.Clock.put(key, value.Vector)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.lock.Unlock()
|
func (m *TwoPhaseMap[K, D]) Prune() {
|
||||||
|
m.addMap.Prune()
|
||||||
|
m.removeMap.Prune()
|
||||||
|
m.Clock.Prune()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
|
// NewTwoPhaseMap: create a new two phase map. Consists of two maps
|
||||||
// a grow map and a remove map. If both timestamps equal then favour keeping
|
// a grow map and a remove map. If both timestamps equal then favour keeping
|
||||||
// it in the map
|
// it in the map
|
||||||
func NewTwoPhaseMap[K comparable, D any](processId K) *TwoPhaseMap[K, D] {
|
func NewTwoPhaseMap[K cmp.Ordered, D any](processId K, hashKey func(K) uint64, staleTime uint64) *TwoPhaseMap[K, D] {
|
||||||
m := TwoPhaseMap[K, D]{
|
m := TwoPhaseMap[K, D]{
|
||||||
vectors: make(map[K]uint64),
|
|
||||||
processId: processId,
|
processId: processId,
|
||||||
|
Clock: NewVectorClock(processId, hashKey, staleTime),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addMap = NewGMap[K, D](m.incrementClock)
|
m.addMap = NewGMap[K, D](m.Clock)
|
||||||
m.removeMap = NewGMap[K, bool](m.incrementClock)
|
m.removeMap = NewGMap[K, bool](m.Clock)
|
||||||
return &m
|
return &m
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,8 @@ import (
|
|||||||
type SyncState int
|
type SyncState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PREPARE SyncState = iota
|
HASH SyncState = iota
|
||||||
|
PREPARE
|
||||||
PRESENT
|
PRESENT
|
||||||
EXCHANGE
|
EXCHANGE
|
||||||
MERGE
|
MERGE
|
||||||
@ -26,16 +27,54 @@ type TwoPhaseSyncer struct {
|
|||||||
peerMsg []byte
|
peerMsg []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TwoPhaseHash struct {
|
||||||
|
Hash uint64
|
||||||
|
}
|
||||||
|
|
||||||
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
|
type SyncFSM map[SyncState]func(*TwoPhaseSyncer) ([]byte, bool)
|
||||||
|
|
||||||
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
func hash(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
hash := TwoPhaseHash{
|
||||||
|
Hash: syncer.manager.store.Clock.GetHash(),
|
||||||
|
}
|
||||||
|
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
enc := gob.NewEncoder(&buffer)
|
enc := gob.NewEncoder(&buffer)
|
||||||
|
|
||||||
err := enc.Encode(*syncer.mapState)
|
err := enc.Encode(hash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
syncer.IncrementState()
|
||||||
|
return buffer.Bytes(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
||||||
|
var recvBuffer = bytes.NewBuffer(syncer.peerMsg)
|
||||||
|
dec := gob.NewDecoder(recvBuffer)
|
||||||
|
|
||||||
|
var hash TwoPhaseHash
|
||||||
|
err := dec.Decode(&hash)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteErrorf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// If vector clocks are equal then no need to merge state
|
||||||
|
// Helps to reduce bandwidth by detecting early
|
||||||
|
if hash.Hash == syncer.manager.store.Clock.GetHash() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buffer)
|
||||||
|
|
||||||
|
err = enc.Encode(*syncer.mapState)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
syncer.IncrementState()
|
syncer.IncrementState()
|
||||||
@ -54,10 +93,11 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
err := dec.Decode(&mapState)
|
err := dec.Decode(&mapState)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Log.WriteInfof(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
difference := syncer.mapState.Difference(&mapState)
|
difference := syncer.mapState.Difference(&mapState)
|
||||||
|
syncer.manager.store.Clock.Merge(mapState.Vectors)
|
||||||
|
|
||||||
var sendBuffer bytes.Buffer
|
var sendBuffer bytes.Buffer
|
||||||
enc := gob.NewEncoder(&sendBuffer)
|
enc := gob.NewEncoder(&sendBuffer)
|
||||||
@ -100,7 +140,6 @@ func merge(syncer *TwoPhaseSyncer) ([]byte, bool) {
|
|||||||
dec.Decode(&snapshot)
|
dec.Decode(&snapshot)
|
||||||
|
|
||||||
syncer.manager.store.Merge(snapshot)
|
syncer.manager.store.Merge(snapshot)
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,10 +164,14 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
|
|||||||
|
|
||||||
func (t *TwoPhaseSyncer) Complete() {
|
func (t *TwoPhaseSyncer) Complete() {
|
||||||
logging.Log.WriteInfof("SYNC COMPLETED")
|
logging.Log.WriteInfof("SYNC COMPLETED")
|
||||||
|
if t.state >= MERGE {
|
||||||
|
t.manager.store.Clock.IncrementClock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
||||||
var generateMessageFsm SyncFSM = SyncFSM{
|
var generateMessageFsm SyncFSM = SyncFSM{
|
||||||
|
HASH: hash,
|
||||||
PREPARE: prepare,
|
PREPARE: prepare,
|
||||||
PRESENT: present,
|
PRESENT: present,
|
||||||
EXCHANGE: exchange,
|
EXCHANGE: exchange,
|
||||||
@ -137,7 +180,7 @@ func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
|
|||||||
|
|
||||||
return &TwoPhaseSyncer{
|
return &TwoPhaseSyncer{
|
||||||
manager: manager,
|
manager: manager,
|
||||||
state: PREPARE,
|
state: HASH,
|
||||||
mapState: manager.store.GenerateMessage(),
|
mapState: manager.store.GenerateMessage(),
|
||||||
generateMessageFSM: generateMessageFsm,
|
generateMessageFSM: generateMessageFsm,
|
||||||
}
|
}
|
||||||
|
154
pkg/crdt/vector_clock.go
Normal file
154
pkg/crdt/vector_clock.go
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
package crdt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VectorBucket struct {
|
||||||
|
// clock current value of the node's clock
|
||||||
|
clock uint64
|
||||||
|
// lastUpdate we've seen
|
||||||
|
lastUpdate uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vector clock defines an abstract data type
|
||||||
|
// for a vector clock implementation
|
||||||
|
type VectorClock[K cmp.Ordered] struct {
|
||||||
|
vectors map[uint64]*VectorBucket
|
||||||
|
lock sync.RWMutex
|
||||||
|
processID K
|
||||||
|
staleTime uint64
|
||||||
|
hashFunc func(K) uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementClock: increments the node's value in the vector clock
|
||||||
|
func (m *VectorClock[K]) IncrementClock() uint64 {
|
||||||
|
maxClock := uint64(0)
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
for _, value := range m.vectors {
|
||||||
|
maxClock = max(maxClock, value.clock)
|
||||||
|
}
|
||||||
|
|
||||||
|
newBucket := VectorBucket{
|
||||||
|
clock: maxClock + 1,
|
||||||
|
lastUpdate: uint64(time.Now().Unix()),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.vectors[m.hashFunc(m.processID)] = &newBucket
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
return maxClock
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHash: gets the hash of the vector clock used to determine if there
|
||||||
|
// are any changes
|
||||||
|
func (m *VectorClock[K]) GetHash() uint64 {
|
||||||
|
m.lock.RLock()
|
||||||
|
|
||||||
|
hash := uint64(0)
|
||||||
|
|
||||||
|
for key, bucket := range m.vectors {
|
||||||
|
hash += key * (bucket.clock + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
|
||||||
|
for key, value := range vectors {
|
||||||
|
m.put(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStale: get all entries that are stale within the mesh
|
||||||
|
func (m *VectorClock[K]) getStale() []uint64 {
|
||||||
|
m.lock.RLock()
|
||||||
|
maxTimeStamp := lib.Reduce(0, lib.MapValues(m.vectors), func(i uint64, vb *VectorBucket) uint64 {
|
||||||
|
return max(i, vb.lastUpdate)
|
||||||
|
})
|
||||||
|
|
||||||
|
toRemove := make([]uint64, 0)
|
||||||
|
|
||||||
|
for key, bucket := range m.vectors {
|
||||||
|
if maxTimeStamp-bucket.lastUpdate > m.staleTime {
|
||||||
|
toRemove = append(toRemove, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return toRemove
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) Prune() {
|
||||||
|
stale := m.getStale()
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
|
||||||
|
for _, key := range stale {
|
||||||
|
delete(m.vectors, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
|
||||||
|
m.lock.RLock()
|
||||||
|
|
||||||
|
lastUpdate := m.vectors[m.hashFunc(m.processID)].lastUpdate
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return lastUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) Put(key K, value uint64) {
|
||||||
|
m.put(m.hashFunc(key), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) put(key uint64, value uint64) {
|
||||||
|
clockValue := uint64(0)
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
bucket, ok := m.vectors[key]
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
clockValue = bucket.clock
|
||||||
|
}
|
||||||
|
|
||||||
|
if value > clockValue {
|
||||||
|
newBucket := VectorBucket{
|
||||||
|
clock: value,
|
||||||
|
lastUpdate: uint64(time.Now().Unix()),
|
||||||
|
}
|
||||||
|
m.vectors[key] = &newBucket
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
|
||||||
|
clock := make(map[uint64]uint64)
|
||||||
|
|
||||||
|
m.lock.RLock()
|
||||||
|
|
||||||
|
for key, value := range m.vectors {
|
||||||
|
clock[key] = value.clock
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.RUnlock()
|
||||||
|
return clock
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVectorClock[K cmp.Ordered](processID K, hashFunc func(K) uint64, staleTime uint64) *VectorClock[K] {
|
||||||
|
return &VectorClock[K]{
|
||||||
|
vectors: make(map[uint64]*VectorBucket),
|
||||||
|
processID: processID,
|
||||||
|
staleTime: staleTime,
|
||||||
|
hashFunc: hashFunc,
|
||||||
|
}
|
||||||
|
}
|
@ -16,7 +16,7 @@ import (
|
|||||||
|
|
||||||
// NewCtrlServerParams are the params requried to create a new ctrl server
|
// NewCtrlServerParams are the params requried to create a new ctrl server
|
||||||
type NewCtrlServerParams struct {
|
type NewCtrlServerParams struct {
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.DaemonConfiguration
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
CtrlProvider rpc.MeshCtrlServerServer
|
CtrlProvider rpc.MeshCtrlServerServer
|
||||||
SyncProvider rpc.SyncServiceServer
|
SyncProvider rpc.SyncServiceServer
|
||||||
@ -28,7 +28,9 @@ type NewCtrlServerParams struct {
|
|||||||
// operation failed
|
// operation failed
|
||||||
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||||
ctrlServer := new(MeshCtrlServer)
|
ctrlServer := new(MeshCtrlServer)
|
||||||
meshFactory := &crdt.TwoPhaseMapFactory{}
|
meshFactory := &crdt.TwoPhaseMapFactory{
|
||||||
|
Config: params.Conf,
|
||||||
|
}
|
||||||
nodeFactory := &crdt.MeshNodeFactory{
|
nodeFactory := &crdt.MeshNodeFactory{
|
||||||
Config: *params.Conf,
|
Config: *params.Conf,
|
||||||
}
|
}
|
||||||
@ -36,7 +38,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
|||||||
ipAllocator := &ip.ULABuilder{}
|
ipAllocator := &ip.ULABuilder{}
|
||||||
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
|
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
|
||||||
|
|
||||||
configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
|
configApplyer := mesh.NewWgMeshConfigApplyer()
|
||||||
|
|
||||||
meshManagerParams := &mesh.NewMeshManagerParams{
|
meshManagerParams := &mesh.NewMeshManagerParams{
|
||||||
Conf: *params.Conf,
|
Conf: *params.Conf,
|
||||||
@ -87,7 +89,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
|||||||
return ctrlServer, nil
|
return ctrlServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration {
|
func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
|
||||||
return s.Conf
|
return s.Conf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ type Mesh struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CtrlServer interface {
|
type CtrlServer interface {
|
||||||
GetConfiguration() *conf.WgMeshConfiguration
|
GetConfiguration() *conf.DaemonConfiguration
|
||||||
GetClient() *wgctrl.Client
|
GetClient() *wgctrl.Client
|
||||||
GetQuerier() query.Querier
|
GetQuerier() query.Querier
|
||||||
GetMeshManager() mesh.MeshManager
|
GetMeshManager() mesh.MeshManager
|
||||||
@ -48,6 +48,6 @@ type MeshCtrlServer struct {
|
|||||||
MeshManager mesh.MeshManager
|
MeshManager mesh.MeshManager
|
||||||
ConnectionManager conn.ConnectionManager
|
ConnectionManager conn.ConnectionManager
|
||||||
ConnectionServer *conn.ConnectionServer
|
ConnectionServer *conn.ConnectionServer
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.DaemonConfiguration
|
||||||
Querier query.Querier
|
Querier query.Querier
|
||||||
}
|
}
|
||||||
|
@ -23,10 +23,10 @@ func NewCtrlServerStub() *CtrlServerStub {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration {
|
func (c *CtrlServerStub) GetConfiguration() *conf.DaemonConfiguration {
|
||||||
return &conf.WgMeshConfiguration{
|
return &conf.DaemonConfiguration{
|
||||||
GrpcPort: "8080",
|
GrpcPort: 8080,
|
||||||
Endpoint: "abc.com",
|
BaseConfiguration: conf.WgConfiguration{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ type NewMeshArgs struct {
|
|||||||
// Endpoint is the routable alias of the machine. Can be an IP
|
// Endpoint is the routable alias of the machine. Can be an IP
|
||||||
// or DNS entry
|
// or DNS entry
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
Role string
|
||||||
}
|
}
|
||||||
|
|
||||||
type JoinMeshArgs struct {
|
type JoinMeshArgs struct {
|
||||||
@ -25,12 +26,12 @@ type JoinMeshArgs struct {
|
|||||||
IpAdress string
|
IpAdress string
|
||||||
// Port is the WireGuard port to expose
|
// Port is the WireGuard port to expose
|
||||||
Port int
|
Port int
|
||||||
// Endpoint is the routable address of this machine. If not provided
|
// Endpoint to use to override the default
|
||||||
// defaults to the default address
|
|
||||||
Endpoint string
|
Endpoint string
|
||||||
// Client specifies whether we should join as a client of the peer
|
// Client specifies whether we should join as a client of the peer
|
||||||
// we are connecting to
|
// we are connecting to
|
||||||
Client bool
|
Client bool
|
||||||
|
Role string
|
||||||
}
|
}
|
||||||
|
|
||||||
type PutServiceArgs struct {
|
type PutServiceArgs struct {
|
||||||
|
@ -1,11 +1,34 @@
|
|||||||
package lib
|
package lib
|
||||||
|
|
||||||
|
import "cmp"
|
||||||
|
|
||||||
// MapToSlice converts a map to a slice in go
|
// MapToSlice converts a map to a slice in go
|
||||||
func MapValues[K comparable, V any](m map[K]V) []V {
|
func MapValues[K cmp.Ordered, V any](m map[K]V) []V {
|
||||||
return MapValuesWithExclude(m, map[K]struct{}{})
|
return MapValuesWithExclude(m, map[K]struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}) []V {
|
type MapItemsEntry[K cmp.Ordered, V any] struct {
|
||||||
|
Key K
|
||||||
|
Value V
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapItems[K cmp.Ordered, V any](m map[K]V) []MapItemsEntry[K, V] {
|
||||||
|
keys := MapKeys(m)
|
||||||
|
values := MapValues(m)
|
||||||
|
|
||||||
|
vs := make([]MapItemsEntry[K, V], len(keys))
|
||||||
|
|
||||||
|
for index, _ := range keys {
|
||||||
|
vs[index] = MapItemsEntry[K, V]{
|
||||||
|
Key: keys[index],
|
||||||
|
Value: values[index],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapValuesWithExclude[K cmp.Ordered, V any](m map[K]V, exclude map[K]struct{}) []V {
|
||||||
values := make([]V, len(m)-len(exclude))
|
values := make([]V, len(m)-len(exclude))
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
@ -26,7 +49,7 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
|
|||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func MapKeys[K comparable, V any](m map[K]V) []K {
|
func MapKeys[K cmp.Ordered, V any](m map[K]V) []K {
|
||||||
values := make([]K, len(m))
|
values := make([]K, len(m))
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
|
@ -140,26 +140,38 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
|
|||||||
family = unix.AF_INET
|
family = unix.AF_INET
|
||||||
}
|
}
|
||||||
|
|
||||||
attr := rtnetlink.RouteAttributes{
|
routes, err := c.listRoutes(ifName, family)
|
||||||
Dst: dst.IP,
|
|
||||||
OutIface: uint32(iface.Index),
|
|
||||||
Gateway: gw,
|
|
||||||
}
|
|
||||||
|
|
||||||
ones, _ := dst.Mask.Size()
|
|
||||||
|
|
||||||
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
|
|
||||||
Family: family,
|
|
||||||
Table: unix.RT_TABLE_MAIN,
|
|
||||||
Protocol: unix.RTPROT_BOOT,
|
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
|
||||||
Type: unix.RTN_UNICAST,
|
|
||||||
DstLength: uint8(ones),
|
|
||||||
Attributes: attr,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add route %w", err)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it already exists no need to add the route
|
||||||
|
if !Contains(routes, func(prevRoute rtnetlink.RouteMessage) bool {
|
||||||
|
return prevRoute.Attributes.Dst.Equal(route.Destination.IP) &&
|
||||||
|
prevRoute.Attributes.Gateway.Equal(route.Gateway)
|
||||||
|
}) {
|
||||||
|
attr := rtnetlink.RouteAttributes{
|
||||||
|
Dst: dst.IP,
|
||||||
|
OutIface: uint32(iface.Index),
|
||||||
|
Gateway: gw,
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := dst.Mask.Size()
|
||||||
|
|
||||||
|
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
|
||||||
|
Family: family,
|
||||||
|
Table: unix.RT_TABLE_MAIN,
|
||||||
|
Protocol: unix.RTPROT_BOOT,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
DstLength: uint8(ones),
|
||||||
|
Attributes: attr,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add route %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -248,6 +260,14 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
|
|||||||
if route.equal(r) {
|
if route.equal(r) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -255,7 +275,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
|
|||||||
toDelete := Filter(ifRoutes, shouldExclude)
|
toDelete := Filter(ifRoutes, shouldExclude)
|
||||||
|
|
||||||
for _, route := range toDelete {
|
for _, route := range toDelete {
|
||||||
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
|
logging.Log.WriteInfof("Deleting route: %s", route.Destination.String())
|
||||||
err := c.DeleteRoute(ifName, route)
|
err := c.DeleteRoute(ifName, route)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
40
pkg/lib/stats.go
Normal file
40
pkg/lib/stats.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// lib contains helper functions for the implementation
|
||||||
|
package lib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"gonum.org/v1/gonum/stat"
|
||||||
|
"gonum.org/v1/gonum/stat/distuv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Modelling the distribution using a normal distribution get the count
|
||||||
|
// of the outliers
|
||||||
|
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
|
||||||
|
n := float64(len(counts))
|
||||||
|
|
||||||
|
keys := MapKeys(counts)
|
||||||
|
values := make([]float64, len(keys))
|
||||||
|
|
||||||
|
for index, key := range keys {
|
||||||
|
values[index] = float64(counts[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
mean := stat.Mean(values, nil)
|
||||||
|
stdDev := stat.StdDev(values, nil)
|
||||||
|
|
||||||
|
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
|
||||||
|
|
||||||
|
lowerBound := mean - moe
|
||||||
|
|
||||||
|
var outliers []K
|
||||||
|
|
||||||
|
for i, count := range values {
|
||||||
|
if count < lowerBound {
|
||||||
|
outliers = append(outliers, keys[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return outliers
|
||||||
|
}
|
@ -10,7 +10,6 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/route"
|
"github.com/tim-beatham/wgmesh/pkg/route"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@ -25,8 +24,8 @@ type MeshConfigApplyer interface {
|
|||||||
// WgMeshConfigApplyer applies WireGuard configuration
|
// WgMeshConfigApplyer applies WireGuard configuration
|
||||||
type WgMeshConfigApplyer struct {
|
type WgMeshConfigApplyer struct {
|
||||||
meshManager MeshManager
|
meshManager MeshManager
|
||||||
config *conf.WgMeshConfiguration
|
|
||||||
routeInstaller route.RouteInstaller
|
routeInstaller route.RouteInstaller
|
||||||
|
hashFunc func(MeshNode) int
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeNode struct {
|
type routeNode struct {
|
||||||
@ -34,49 +33,44 @@ type routeNode struct {
|
|||||||
route Route
|
route Route
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
|
type convertMeshNodeParams struct {
|
||||||
peerToClients map[string][]net.IPNet,
|
node MeshNode
|
||||||
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
|
self MeshNode
|
||||||
|
mesh MeshProvider
|
||||||
|
device *wgtypes.Device
|
||||||
|
peerToClients map[string][]net.IPNet
|
||||||
|
routes map[string][]routeNode
|
||||||
|
}
|
||||||
|
|
||||||
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
|
func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) {
|
||||||
|
pubKey, err := params.node.GetPublicKey()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, err := node.GetPublicKey()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
allowedips := make([]net.IPNet, 1)
|
allowedips := make([]net.IPNet, 1)
|
||||||
allowedips[0] = *node.GetWgHost()
|
allowedips[0] = *params.node.GetWgHost()
|
||||||
|
|
||||||
clients, ok := peerToClients[pubKey.String()]
|
clients, ok := params.peerToClients[pubKey.String()]
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
allowedips = append(allowedips, clients...)
|
allowedips = append(allowedips, clients...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range node.GetRoutes() {
|
for _, route := range params.node.GetRoutes() {
|
||||||
bestRoutes := routes[route.GetDestination().String()]
|
bestRoutes := params.routes[route.GetDestination().String()]
|
||||||
var pickedRoute routeNode
|
var pickedRoute routeNode
|
||||||
|
|
||||||
if len(bestRoutes) == 1 {
|
if len(bestRoutes) == 1 {
|
||||||
pickedRoute = bestRoutes[0]
|
pickedRoute = bestRoutes[0]
|
||||||
} else if len(bestRoutes) > 1 {
|
} else if len(bestRoutes) > 1 {
|
||||||
keyFunc := func(mn MeshNode) int {
|
|
||||||
pubKey, _ := mn.GetPublicKey()
|
|
||||||
return lib.HashString(pubKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
bucketFunc := func(rn routeNode) int {
|
bucketFunc := func(rn routeNode) int {
|
||||||
return lib.HashString(rn.gateway)
|
return lib.HashString(rn.gateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Else there is more than one candidate so consistently hash
|
// Else there is more than one candidate so consistently hash
|
||||||
pickedRoute = lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
|
pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pickedRoute.gateway == pubKey.String() {
|
if pickedRoute.gateway == pubKey.String() {
|
||||||
@ -84,15 +78,28 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
config := params.mesh.GetConfiguration()
|
||||||
|
|
||||||
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
|
var keepAlive time.Duration = time.Duration(0)
|
||||||
pubKey, _ := node.GetPublicKey()
|
|
||||||
|
if config.KeepAliveWg != nil {
|
||||||
|
keepAlive = time.Duration(*config.KeepAliveWg) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := slices.IndexFunc(params.device.Peers, func(p wgtypes.Peer) bool {
|
||||||
|
pubKey, _ := params.node.GetPublicKey()
|
||||||
return p.PublicKey.String() == pubKey.String()
|
return p.PublicKey.String() == pubKey.String()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't override the existing IP in case it already exists
|
||||||
if existing != -1 {
|
if existing != -1 {
|
||||||
endpoint = device.Peers[existing].Endpoint
|
endpoint = params.device.Peers[existing].Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConfig := wgtypes.PeerConfig{
|
peerConfig := wgtypes.PeerConfig{
|
||||||
@ -110,13 +117,15 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Dev
|
|||||||
// consistently hash to evenly spread the distribution of traffic
|
// consistently hash to evenly spread the distribution of traffic
|
||||||
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
|
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
|
||||||
mesh, _ := meshProvider.GetMesh()
|
mesh, _ := meshProvider.GetMesh()
|
||||||
|
|
||||||
routes := make(map[string][]routeNode)
|
routes := make(map[string][]routeNode)
|
||||||
|
|
||||||
|
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
|
||||||
|
return p.GetType() == conf.PEER_ROLE
|
||||||
|
})
|
||||||
|
|
||||||
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
|
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
|
||||||
ula := &ip.ULABuilder{}
|
ula := &ip.ULABuilder{}
|
||||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||||
|
|
||||||
return ipNet
|
return ipNet
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -125,6 +134,13 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
|
|
||||||
for _, route := range node.GetRoutes() {
|
for _, route := range node.GetRoutes() {
|
||||||
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
|
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
|
||||||
|
v6Default, _, _ := net.ParseCIDR("::/0")
|
||||||
|
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
|
|
||||||
|
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return prefix.Contains(route.GetDestination().IP)
|
return prefix.Contains(route.GetDestination().IP)
|
||||||
}) {
|
}) {
|
||||||
continue
|
continue
|
||||||
@ -138,6 +154,24 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
route: route,
|
route: route,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Client's only acessible by another peer
|
||||||
|
if node.GetType() == conf.CLIENT_ROLE {
|
||||||
|
peer := m.getCorrespondingPeer(peers, node)
|
||||||
|
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId())
|
||||||
|
|
||||||
|
// If the node isn't the self use that peer as the gateway
|
||||||
|
if !NodeEquals(peer, self) {
|
||||||
|
peerPub, _ := peer.GetPublicKey()
|
||||||
|
rn.gateway = peerPub.String()
|
||||||
|
rn.route = &RouteStub{
|
||||||
|
Destination: rn.route.GetDestination(),
|
||||||
|
HopCount: rn.route.GetHopCount() + 1,
|
||||||
|
// Append the path to this peer
|
||||||
|
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
otherRoute = make([]routeNode, 1)
|
otherRoute = make([]routeNode, 1)
|
||||||
otherRoute[0] = rn
|
otherRoute[0] = rn
|
||||||
@ -145,8 +179,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
|
||||||
otherRoute[0] = rn
|
otherRoute[0] = rn
|
||||||
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
|
||||||
logging.Log.WriteInfof("Other Route Hop: %d", otherRoute[0].route.GetHopCount())
|
|
||||||
logging.Log.WriteInfof("Route gateway %s, route hop %d", rn.gateway, route.GetHopCount())
|
|
||||||
routes[destination] = append(otherRoute, rn)
|
routes[destination] = append(otherRoute, rn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -157,67 +189,127 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
|
|||||||
|
|
||||||
// getCorrespondignPeer: gets the peer corresponding to the client
|
// getCorrespondignPeer: gets the peer corresponding to the client
|
||||||
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
|
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
|
||||||
hashFunc := func(mn MeshNode) int {
|
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
|
||||||
pubKey, _ := mn.GetPublicKey()
|
|
||||||
return lib.HashString(pubKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := lib.ConsistentHash(peers, client, hashFunc, hashFunc)
|
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) getClientConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode) (*wgtypes.Config, error) {
|
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
|
||||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
peers := dev.Peers
|
||||||
|
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
|
||||||
|
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
|
||||||
|
return p1.PublicKey.String() == p2.PublicKey.String()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return lib.Map(peers, func(p wgtypes.Peer) wgtypes.PeerConfig {
|
||||||
|
return wgtypes.PeerConfig{
|
||||||
|
PublicKey: p.PublicKey,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetConfigParams struct {
|
||||||
|
mesh MeshProvider
|
||||||
|
peers []MeshNode
|
||||||
|
clients []MeshNode
|
||||||
|
dev *wgtypes.Device
|
||||||
|
routes map[string][]routeNode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
|
||||||
|
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
|
||||||
|
ula := &ip.ULABuilder{}
|
||||||
|
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
|
||||||
|
|
||||||
|
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
|
||||||
|
return lib.Filter(rns, func(rn routeNode) bool {
|
||||||
|
ip, _, _ := net.ParseCIDR(rn.gateway)
|
||||||
|
return meshNet.Contains(ip)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
|
||||||
|
return *rs[0].route.GetDestination()
|
||||||
|
})
|
||||||
|
routes = append(routes, *meshNet)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := m.getCorrespondingPeer(peers, self)
|
peer := m.getCorrespondingPeer(params.peers, self)
|
||||||
|
|
||||||
pubKey, _ := peer.GetPublicKey()
|
pubKey, _ := peer.GetPublicKey()
|
||||||
|
|
||||||
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
|
config := params.mesh.GetConfiguration()
|
||||||
|
|
||||||
|
keepAlive := time.Duration(*config.KeepAliveWg) * time.Second
|
||||||
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
allowedips := make([]net.IPNet, 1)
|
|
||||||
_, ipnet, _ := net.ParseCIDR("::/0")
|
|
||||||
allowedips[0] = *ipnet
|
|
||||||
|
|
||||||
peerCfgs := make([]wgtypes.PeerConfig, 1)
|
peerCfgs := make([]wgtypes.PeerConfig, 1)
|
||||||
|
|
||||||
peerCfgs[0] = wgtypes.PeerConfig{
|
peerCfgs[0] = wgtypes.PeerConfig{
|
||||||
PublicKey: pubKey,
|
PublicKey: pubKey,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
AllowedIPs: allowedips,
|
AllowedIPs: routes,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
installedRoutes := make([]lib.Route, 0)
|
||||||
|
|
||||||
|
for _, route := range peerCfgs[0].AllowedIPs {
|
||||||
|
installedRoutes = append(installedRoutes, lib.Route{
|
||||||
|
Gateway: peer.GetWgHost().IP,
|
||||||
|
Destination: route,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := wgtypes.Config{
|
cfg := wgtypes.Config{
|
||||||
Peers: peerCfgs,
|
Peers: peerCfgs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
|
||||||
return &cfg, err
|
return &cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode, clients []MeshNode, dev *wgtypes.Device) (*wgtypes.Config, error) {
|
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
|
||||||
|
routes := make([]lib.Route, 0)
|
||||||
|
|
||||||
|
for _, route := range wgNode.AllowedIPs {
|
||||||
|
ula := &ip.ULABuilder{}
|
||||||
|
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
||||||
|
|
||||||
|
_, defaultRoute, _ := net.ParseCIDR("::/0")
|
||||||
|
|
||||||
|
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
|
||||||
|
routes = append(routes, lib.Route{
|
||||||
|
Gateway: node.GetWgHost().IP,
|
||||||
|
Destination: route,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
|
||||||
peerToClients := make(map[string][]net.IPNet)
|
peerToClients := make(map[string][]net.IPNet)
|
||||||
routes := m.getRoutes(mesh)
|
|
||||||
installedRoutes := make([]lib.Route, 0)
|
installedRoutes := make([]lib.Route, 0)
|
||||||
peerConfigs := make([]wgtypes.PeerConfig, 0)
|
peerConfigs := make([]wgtypes.PeerConfig, 0)
|
||||||
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
|
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, n := range clients {
|
for _, n := range params.clients {
|
||||||
if len(peers) > 0 {
|
if len(params.peers) > 0 {
|
||||||
peer := m.getCorrespondingPeer(peers, n)
|
peer := m.getCorrespondingPeer(params.peers, n)
|
||||||
pubKey, _ := peer.GetPublicKey()
|
pubKey, _ := peer.GetPublicKey()
|
||||||
clients, ok := peerToClients[pubKey.String()]
|
clients, ok := peerToClients[pubKey.String()]
|
||||||
|
|
||||||
@ -229,53 +321,56 @@ func (m *WgMeshConfigApplyer) getPeerConfig(mesh MeshProvider, peers []MeshNode,
|
|||||||
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
|
||||||
|
|
||||||
if NodeEquals(self, peer) {
|
if NodeEquals(self, peer) {
|
||||||
cfg, err := m.convertMeshNode(n, dev, peerToClients, routes)
|
cfg, err := m.convertMeshNode(convertMeshNodeParams{
|
||||||
|
node: n,
|
||||||
|
self: self,
|
||||||
|
mesh: params.mesh,
|
||||||
|
device: params.dev,
|
||||||
|
peerToClients: peerToClients,
|
||||||
|
routes: params.routes,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
|
||||||
peerConfigs = append(peerConfigs, *cfg)
|
peerConfigs = append(peerConfigs, *cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, n := range peers {
|
for _, n := range params.peers {
|
||||||
if NodeEquals(n, self) {
|
if NodeEquals(n, self) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
|
peer, err := m.convertMeshNode(convertMeshNodeParams{
|
||||||
|
node: n,
|
||||||
|
self: self,
|
||||||
|
mesh: params.mesh,
|
||||||
|
peerToClients: peerToClients,
|
||||||
|
routes: params.routes,
|
||||||
|
device: params.dev,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range peer.AllowedIPs {
|
installedRoutes = append(installedRoutes, m.getRoutesToInstall(peer, params.mesh, n)...)
|
||||||
ula := &ip.ULABuilder{}
|
|
||||||
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
|
|
||||||
|
|
||||||
if !ipNet.Contains(route.IP) {
|
|
||||||
installedRoutes = append(installedRoutes, lib.Route{
|
|
||||||
Gateway: n.GetWgHost().IP,
|
|
||||||
Destination: route,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peerConfigs = append(peerConfigs, *peer)
|
peerConfigs = append(peerConfigs, *peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := wgtypes.Config{
|
cfg := wgtypes.Config{
|
||||||
Peers: peerConfigs,
|
Peers: peerConfigs,
|
||||||
ReplacePeers: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
|
err = m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
|
||||||
return &cfg, err
|
return &cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
|
||||||
snap, err := mesh.GetMesh()
|
snap, err := mesh.GetMesh()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -305,17 +400,28 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
|
|
||||||
var cfg *wgtypes.Config = nil
|
var cfg *wgtypes.Config = nil
|
||||||
|
|
||||||
|
configParams := &GetConfigParams{
|
||||||
|
mesh: mesh,
|
||||||
|
peers: peers,
|
||||||
|
clients: clients,
|
||||||
|
dev: dev,
|
||||||
|
routes: routes,
|
||||||
|
}
|
||||||
|
|
||||||
switch self.GetType() {
|
switch self.GetType() {
|
||||||
case conf.PEER_ROLE:
|
case conf.PEER_ROLE:
|
||||||
cfg, err = m.getPeerConfig(mesh, peers, clients, dev)
|
cfg, err = m.getPeerConfig(configParams)
|
||||||
case conf.CLIENT_ROLE:
|
case conf.CLIENT_ROLE:
|
||||||
cfg, err = m.getClientConfig(mesh, peers, clients)
|
cfg, err = m.getClientConfig(configParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toRemove := m.getPeerCfgsToRemove(dev, cfg.Peers)
|
||||||
|
cfg.Peers = append(cfg.Peers, toRemove...)
|
||||||
|
|
||||||
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
err = m.meshManager.GetClient().ConfigureDevice(dev.Name, *cfg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -325,9 +431,36 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
|
||||||
|
allRoutes := make(map[string][]routeNode)
|
||||||
|
|
||||||
for _, mesh := range m.meshManager.GetMeshes() {
|
for _, mesh := range m.meshManager.GetMeshes() {
|
||||||
err := m.updateWgConf(mesh)
|
routes := m.getRoutes(mesh)
|
||||||
|
|
||||||
|
for destination, route := range routes {
|
||||||
|
_, ok := allRoutes[destination]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
allRoutes[destination] = route
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if allRoutes[destination][0].route.GetHopCount() == route[0].route.GetHopCount() {
|
||||||
|
allRoutes[destination] = append(allRoutes[destination], route...)
|
||||||
|
} else if route[0].route.GetHopCount() < allRoutes[destination][0].route.GetHopCount() {
|
||||||
|
allRoutes[destination] = route
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||||
|
allRoutes := m.getAllRoutes()
|
||||||
|
|
||||||
|
for _, mesh := range m.meshManager.GetMeshes() {
|
||||||
|
err := m.updateWgConf(mesh, allRoutes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -362,9 +495,12 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
|
|||||||
m.meshManager = manager
|
m.meshManager = manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
|
func NewWgMeshConfigApplyer() MeshConfigApplyer {
|
||||||
return &WgMeshConfigApplyer{
|
return &WgMeshConfigApplyer{
|
||||||
config: config,
|
|
||||||
routeInstaller: route.NewRouteInstaller(),
|
routeInstaller: route.NewRouteInstaller(),
|
||||||
|
hashFunc: func(mn MeshNode) int {
|
||||||
|
pubKey, _ := mn.GetPublicKey()
|
||||||
|
return lib.HashString(pubKey.String())
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,17 +5,17 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/cmd"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MeshManager interface {
|
type MeshManager interface {
|
||||||
CreateMesh(port int) (string, error)
|
CreateMesh(params *CreateMeshParams) (string, error)
|
||||||
AddMesh(params *AddMeshParams) error
|
AddMesh(params *AddMeshParams) error
|
||||||
HasChanges(meshid string) bool
|
HasChanges(meshid string) bool
|
||||||
GetMesh(meshId string) MeshProvider
|
GetMesh(meshId string) MeshProvider
|
||||||
@ -31,7 +31,6 @@ type MeshManager interface {
|
|||||||
UpdateTimeStamp() error
|
UpdateTimeStamp() error
|
||||||
GetClient() *wgctrl.Client
|
GetClient() *wgctrl.Client
|
||||||
GetMeshes() map[string]MeshProvider
|
GetMeshes() map[string]MeshProvider
|
||||||
Prune() error
|
|
||||||
Close() error
|
Close() error
|
||||||
GetMonitor() MeshMonitor
|
GetMonitor() MeshMonitor
|
||||||
GetNode(string, string) MeshNode
|
GetNode(string, string) MeshNode
|
||||||
@ -46,7 +45,7 @@ type MeshManagerImpl struct {
|
|||||||
// HostParameters contains information that uniquely locates
|
// HostParameters contains information that uniquely locates
|
||||||
// the node in the mesh network.
|
// the node in the mesh network.
|
||||||
HostParameters *HostParameters
|
HostParameters *HostParameters
|
||||||
conf *conf.WgMeshConfiguration
|
conf *conf.DaemonConfiguration
|
||||||
meshProviderFactory MeshProviderFactory
|
meshProviderFactory MeshProviderFactory
|
||||||
nodeFactory MeshNodeFactory
|
nodeFactory MeshNodeFactory
|
||||||
configApplyer MeshConfigApplyer
|
configApplyer MeshConfigApplyer
|
||||||
@ -54,6 +53,7 @@ type MeshManagerImpl struct {
|
|||||||
ipAllocator ip.IPAllocator
|
ipAllocator ip.IPAllocator
|
||||||
interfaceManipulator wg.WgInterfaceManipulator
|
interfaceManipulator wg.WgInterfaceManipulator
|
||||||
Monitor MeshMonitor
|
Monitor MeshMonitor
|
||||||
|
cmdRunner cmd.CmdRunner
|
||||||
OnDelete func(MeshProvider)
|
OnDelete func(MeshProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,21 +109,38 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
|
|||||||
return m.Monitor
|
return m.Monitor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune implements MeshManager.
|
// CreateMeshParams contains the parameters required to create a mesh
|
||||||
func (m *MeshManagerImpl) Prune() error {
|
type CreateMeshParams struct {
|
||||||
for _, mesh := range m.Meshes {
|
Port int
|
||||||
err := mesh.Prune(m.conf.PruneTime)
|
Conf *conf.WgConfiguration
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConf: gets the new configuration with the base configuration overriden
|
||||||
|
// from the recent
|
||||||
|
func (m *MeshManagerImpl) getConf(override *conf.WgConfiguration) (*conf.WgConfiguration, error) {
|
||||||
|
meshConfiguration := m.conf.BaseConfiguration
|
||||||
|
|
||||||
|
if override != nil {
|
||||||
|
newConf, err := conf.MergeMeshConfiguration(meshConfiguration, *override)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
meshConfiguration = newConf
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return &meshConfiguration, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
|
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
|
||||||
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
|
||||||
|
meshConfiguration, err := m.getConf(args.Conf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
meshId, err := m.idGenerator.GetId()
|
meshId, err := m.idGenerator.GetId()
|
||||||
|
|
||||||
var ifName string = ""
|
var ifName string = ""
|
||||||
@ -132,8 +149,10 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PreUp...)
|
||||||
|
|
||||||
if !m.conf.StubWg {
|
if !m.conf.StubWg {
|
||||||
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
|
ifName, err = m.interfaceManipulator.CreateInterface(args.Port, m.HostParameters.PrivateKey)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error creating mesh: %w", err)
|
return "", fmt.Errorf("error creating mesh: %w", err)
|
||||||
@ -141,12 +160,13 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
||||||
DevName: ifName,
|
DevName: ifName,
|
||||||
Port: port,
|
Port: args.Port,
|
||||||
Conf: m.conf,
|
Conf: meshConfiguration,
|
||||||
Client: m.Client,
|
Client: m.Client,
|
||||||
MeshId: meshId,
|
MeshId: meshId,
|
||||||
NodeID: m.HostParameters.GetPublicKey(),
|
DaemonConf: m.conf,
|
||||||
|
NodeID: m.HostParameters.GetPublicKey(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -156,6 +176,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
|||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
m.Meshes[meshId] = nodeManager
|
m.Meshes[meshId] = nodeManager
|
||||||
m.lock.Unlock()
|
m.lock.Unlock()
|
||||||
|
|
||||||
|
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
|
||||||
|
|
||||||
return meshId, nil
|
return meshId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,6 +186,7 @@ type AddMeshParams struct {
|
|||||||
MeshId string
|
MeshId string
|
||||||
WgPort int
|
WgPort int
|
||||||
MeshBytes []byte
|
MeshBytes []byte
|
||||||
|
Conf *conf.WgConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMesh: Add the mesh to the list of meshes
|
// AddMesh: Add the mesh to the list of meshes
|
||||||
@ -170,6 +194,14 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
|||||||
var ifName string
|
var ifName string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
meshConfiguration, err := m.getConf(params.Conf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.cmdRunner.RunCommands(meshConfiguration.PreUp...)
|
||||||
|
|
||||||
if !m.conf.StubWg {
|
if !m.conf.StubWg {
|
||||||
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
|
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
|
||||||
|
|
||||||
@ -179,14 +211,17 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
||||||
DevName: ifName,
|
DevName: ifName,
|
||||||
Port: params.WgPort,
|
Port: params.WgPort,
|
||||||
Conf: m.conf,
|
Conf: meshConfiguration,
|
||||||
Client: m.Client,
|
Client: m.Client,
|
||||||
MeshId: params.MeshId,
|
MeshId: params.MeshId,
|
||||||
NodeID: m.HostParameters.GetPublicKey(),
|
DaemonConf: m.conf,
|
||||||
|
NodeID: m.HostParameters.GetPublicKey(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
m.cmdRunner.RunCommands(meshConfiguration.PostUp...)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -256,10 +291,11 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
|
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
|
||||||
PublicKey: &pubKey,
|
PublicKey: &pubKey,
|
||||||
NodeIP: nodeIP,
|
NodeIP: nodeIP,
|
||||||
WgPort: params.WgPort,
|
WgPort: params.WgPort,
|
||||||
Endpoint: params.Endpoint,
|
Endpoint: params.Endpoint,
|
||||||
|
MeshConfig: mesh.GetConfiguration(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if !s.conf.StubWg {
|
if !s.conf.StubWg {
|
||||||
@ -277,7 +313,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.Meshes[params.MeshId].AddNode(node)
|
s.Meshes[params.MeshId].AddNode(node)
|
||||||
return s.RouteManager.UpdateRoutes()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LeaveMesh leaves the mesh network
|
// LeaveMesh leaves the mesh network
|
||||||
@ -288,10 +324,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
|||||||
return fmt.Errorf("mesh %s does not exist", meshId)
|
return fmt.Errorf("mesh %s does not exist", meshId)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
||||||
|
|
||||||
s.RouteManager.RemoveRoutes(meshId)
|
|
||||||
err = mesh.RemoveNode(s.HostParameters.GetPublicKey())
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -305,6 +338,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
|||||||
delete(s.Meshes, meshId)
|
delete(s.Meshes, meshId)
|
||||||
s.lock.Unlock()
|
s.lock.Unlock()
|
||||||
|
|
||||||
|
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
|
||||||
|
|
||||||
if !s.conf.StubWg {
|
if !s.conf.StubWg {
|
||||||
device, err := mesh.GetDevice()
|
device, err := mesh.GetDevice()
|
||||||
|
|
||||||
@ -319,6 +354,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,7 +366,6 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
|
|||||||
return nil, fmt.Errorf("mesh %s does not exist", meshId)
|
return nil, fmt.Errorf("mesh %s does not exist", meshId)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
|
|
||||||
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
|
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -441,7 +477,7 @@ func (s *MeshManagerImpl) Close() error {
|
|||||||
|
|
||||||
// NewMeshManagerParams params required to create an instance of a mesh manager
|
// NewMeshManagerParams params required to create an instance of a mesh manager
|
||||||
type NewMeshManagerParams struct {
|
type NewMeshManagerParams struct {
|
||||||
Conf conf.WgMeshConfiguration
|
Conf conf.DaemonConfiguration
|
||||||
Client *wgctrl.Client
|
Client *wgctrl.Client
|
||||||
MeshProvider MeshProviderFactory
|
MeshProvider MeshProviderFactory
|
||||||
NodeFactory MeshNodeFactory
|
NodeFactory MeshNodeFactory
|
||||||
@ -450,6 +486,7 @@ type NewMeshManagerParams struct {
|
|||||||
InterfaceManipulator wg.WgInterfaceManipulator
|
InterfaceManipulator wg.WgInterfaceManipulator
|
||||||
ConfigApplyer MeshConfigApplyer
|
ConfigApplyer MeshConfigApplyer
|
||||||
RouteManager RouteManager
|
RouteManager RouteManager
|
||||||
|
CommandRunner cmd.CmdRunner
|
||||||
OnDelete func(MeshProvider)
|
OnDelete func(MeshProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,7 +510,11 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
|
|||||||
m.RouteManager = params.RouteManager
|
m.RouteManager = params.RouteManager
|
||||||
|
|
||||||
if m.RouteManager == nil {
|
if m.RouteManager == nil {
|
||||||
m.RouteManager = NewRouteManager(m)
|
m.RouteManager = NewRouteManager(m, ¶ms.Conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.CommandRunner == nil {
|
||||||
|
m.cmdRunner = &cmd.UnixCmdRunner{}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.idGenerator = params.IdGenerator
|
m.idGenerator = params.IdGenerator
|
||||||
|
@ -9,16 +9,9 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getMeshConfiguration() *conf.WgMeshConfiguration {
|
func getMeshConfiguration() *conf.DaemonConfiguration {
|
||||||
return &conf.WgMeshConfiguration{
|
return &conf.DaemonConfiguration{
|
||||||
GrpcPort: "8080",
|
GrpcPort: 8080,
|
||||||
Endpoint: "abc.com",
|
|
||||||
ClusterSize: 64,
|
|
||||||
SyncRate: 4,
|
|
||||||
BranchRate: 3,
|
|
||||||
InterClusterChance: 0.15,
|
|
||||||
InfectionCount: 2,
|
|
||||||
KeepAliveTime: 60,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
package mesh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
|
||||||
)
|
|
||||||
|
|
||||||
func pruneFunction(m MeshManager) lib.TimerFunc {
|
|
||||||
return func() error {
|
|
||||||
return m.Prune()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
|
|
||||||
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
|
|
||||||
}
|
|
@ -1,23 +1,25 @@
|
|||||||
package mesh
|
package mesh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouteManager interface {
|
type RouteManager interface {
|
||||||
UpdateRoutes() error
|
UpdateRoutes() error
|
||||||
RemoveRoutes(meshId string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteManagerImpl struct {
|
type RouteManagerImpl struct {
|
||||||
meshManager MeshManager
|
meshManager MeshManager
|
||||||
|
conf *conf.DaemonConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RouteManagerImpl) UpdateRoutes() error {
|
func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||||
meshes := r.meshManager.GetMeshes()
|
meshes := r.meshManager.GetMeshes()
|
||||||
ulaBuilder := new(ip.ULABuilder)
|
routes := make(map[string][]Route)
|
||||||
|
|
||||||
for _, mesh1 := range meshes {
|
for _, mesh1 := range meshes {
|
||||||
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
|
||||||
@ -26,68 +28,84 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, err := self.GetPublicKey()
|
if _, ok := routes[mesh1.GetMeshId()]; !ok {
|
||||||
|
routes[mesh1.GetMeshId()] = make([]Route, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
routeMap, err := mesh1.GetRoutes(NodeID(self))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := mesh1.GetRoutes(pubKey.String())
|
if *r.conf.BaseConfiguration.AdvertiseDefaultRoute {
|
||||||
|
_, ipv6Default, _ := net.ParseCIDR("::/0")
|
||||||
|
|
||||||
if err != nil {
|
mesh1.AddRoutes(NodeID(self),
|
||||||
return err
|
&RouteStub{
|
||||||
|
Destination: ipv6Default,
|
||||||
|
HopCount: 0,
|
||||||
|
Path: make([]string, 0),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, mesh2 := range meshes {
|
for _, mesh2 := range meshes {
|
||||||
|
routeValues, ok := routes[mesh2.GetMeshId()]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
routeValues = make([]Route, 0)
|
||||||
|
}
|
||||||
|
|
||||||
if mesh1 == mesh2 {
|
if mesh1 == mesh2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ipNet, err := ulaBuilder.GetIPNet(mesh2.GetMeshId())
|
mesh1IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh1.GetMeshId())
|
||||||
|
|
||||||
if err != nil {
|
routeValues = append(routeValues, &RouteStub{
|
||||||
logging.Log.WriteErrorf(err.Error())
|
Destination: mesh1IpNet,
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
|
|
||||||
Destination: ipNet,
|
|
||||||
HopCount: 0,
|
HopCount: 0,
|
||||||
Path: make([]string, 0),
|
Path: []string{mesh1.GetMeshId()},
|
||||||
})...)
|
})
|
||||||
|
|
||||||
if err != nil {
|
routeValues = append(routeValues, lib.MapValues(routeMap)...)
|
||||||
return err
|
mesh2IpNet, _ := (&ip.ULABuilder{}).GetIPNet(mesh2.GetMeshId())
|
||||||
|
routeValues = lib.Filter(routeValues, func(r Route) bool {
|
||||||
|
pathNotMesh := func(s string) bool {
|
||||||
|
return s == mesh2.GetMeshId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that the route does not see it's own IP
|
||||||
|
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh)
|
||||||
|
})
|
||||||
|
|
||||||
|
routes[mesh2.GetMeshId()] = routeValues
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the set different of each, working out routes to remove and to keep.
|
||||||
|
for meshId, meshRoutes := range routes {
|
||||||
|
mesh := r.meshManager.GetMesh(meshId)
|
||||||
|
self, _ := r.meshManager.GetSelf(meshId)
|
||||||
|
toRemove := make([]Route, 0)
|
||||||
|
|
||||||
|
prevRoutes, _ := mesh.GetRoutes(NodeID(self))
|
||||||
|
|
||||||
|
for _, route := range prevRoutes {
|
||||||
|
if !lib.Contains(meshRoutes, func(r Route) bool {
|
||||||
|
return RouteEquals(r, route)
|
||||||
|
}) {
|
||||||
|
toRemove = append(toRemove, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mesh.RemoveRoutes(NodeID(self), toRemove...)
|
||||||
|
mesh.AddRoutes(NodeID(self), meshRoutes...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeRoutes: removes all meshes we are no longer a part of
|
func NewRouteManager(m MeshManager, conf *conf.DaemonConfiguration) RouteManager {
|
||||||
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
|
return &RouteManagerImpl{meshManager: m, conf: conf}
|
||||||
ulaBuilder := new(ip.ULABuilder)
|
|
||||||
meshes := r.meshManager.GetMeshes()
|
|
||||||
|
|
||||||
ipNet, err := ulaBuilder.GetIPNet(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, mesh1 := range meshes {
|
|
||||||
self, err := r.meshManager.GetSelf(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRouteManager(m MeshManager) RouteManager {
|
|
||||||
return &RouteManagerImpl{meshManager: m}
|
|
||||||
}
|
}
|
||||||
|
@ -81,6 +81,16 @@ type MeshProviderStub struct {
|
|||||||
snapshot *MeshSnapshotStub
|
snapshot *MeshSnapshotStub
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfiguration implements MeshProvider.
|
||||||
|
func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration {
|
||||||
|
panic("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark implements MeshProvider.
|
||||||
|
func (*MeshProviderStub) Mark(nodeId string) {
|
||||||
|
panic("unimplemented")
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveNode implements MeshProvider.
|
// RemoveNode implements MeshProvider.
|
||||||
func (*MeshProviderStub) RemoveNode(nodeId string) error {
|
func (*MeshProviderStub) RemoveNode(nodeId string) error {
|
||||||
panic("unimplemented")
|
panic("unimplemented")
|
||||||
@ -117,16 +127,16 @@ func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
|
|||||||
|
|
||||||
// SetAlias implements MeshProvider.
|
// SetAlias implements MeshProvider.
|
||||||
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
|
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
|
||||||
panic("unimplemented")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveRoutes implements MeshProvider.
|
// RemoveRoutes implements MeshProvider.
|
||||||
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
|
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune implements MeshProvider.
|
// Prune implements MeshProvider.
|
||||||
func (*MeshProviderStub) Prune(pruneAmount int) error {
|
func (*MeshProviderStub) Prune() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +200,7 @@ func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams)
|
|||||||
}
|
}
|
||||||
|
|
||||||
type StubNodeFactory struct {
|
type StubNodeFactory struct {
|
||||||
Config *conf.WgMeshConfiguration
|
Config *conf.DaemonConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
|
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
|
||||||
@ -269,7 +279,7 @@ func NewMeshManagerStub() MeshManager {
|
|||||||
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
|
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
|
func (m *MeshManagerStub) CreateMesh(*CreateMeshParams) (string, error) {
|
||||||
return "tim123", nil
|
return "tim123", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ package mesh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
@ -19,6 +20,12 @@ type Route interface {
|
|||||||
GetPath() []string
|
GetPath() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RouteEquals(r1, r2 Route) bool {
|
||||||
|
return r1.GetDestination().String() == r2.GetDestination().String() &&
|
||||||
|
r1.GetHopCount() == r2.GetHopCount() &&
|
||||||
|
slices.Equal(r1.GetPath(), r2.GetPath())
|
||||||
|
}
|
||||||
|
|
||||||
type RouteStub struct {
|
type RouteStub struct {
|
||||||
Destination *net.IPNet
|
Destination *net.IPNet
|
||||||
HopCount int
|
HopCount int
|
||||||
@ -71,11 +78,6 @@ func NodeEquals(node1, node2 MeshNode) bool {
|
|||||||
return key1.String() == key2.String()
|
return key1.String() == key2.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RouteEquals(route1, route2 Route) bool {
|
|
||||||
return route1.GetDestination().String() == route2.GetDestination().String() &&
|
|
||||||
route1.GetHopCount() == route2.GetHopCount()
|
|
||||||
}
|
|
||||||
|
|
||||||
func NodeID(node MeshNode) string {
|
func NodeID(node MeshNode) string {
|
||||||
key, _ := node.GetPublicKey()
|
key, _ := node.GetPublicKey()
|
||||||
return key.String()
|
return key.String()
|
||||||
@ -116,7 +118,7 @@ type MeshProvider interface {
|
|||||||
// AddRoutes: adds routes to the given node
|
// AddRoutes: adds routes to the given node
|
||||||
AddRoutes(nodeId string, route ...Route) error
|
AddRoutes(nodeId string, route ...Route) error
|
||||||
// DeleteRoutes: deletes the routes from the node
|
// DeleteRoutes: deletes the routes from the node
|
||||||
RemoveRoutes(nodeId string, route ...string) error
|
RemoveRoutes(nodeId string, route ...Route) error
|
||||||
// GetSyncer: returns the automerge syncer for sync
|
// GetSyncer: returns the automerge syncer for sync
|
||||||
GetSyncer() MeshSyncer
|
GetSyncer() MeshSyncer
|
||||||
// GetNode get a particular not within the mesh
|
// GetNode get a particular not within the mesh
|
||||||
@ -131,15 +133,21 @@ type MeshProvider interface {
|
|||||||
AddService(nodeId, key, value string) error
|
AddService(nodeId, key, value string) error
|
||||||
// RemoveService: removes the service form the node. throws an error if the service does not exist
|
// RemoveService: removes the service form the node. throws an error if the service does not exist
|
||||||
RemoveService(nodeId, key string) error
|
RemoveService(nodeId, key string) error
|
||||||
// Prune: prunes all nodes that have not updated their timestamp in
|
// Prune: prunes all nodes that have not updated their
|
||||||
// pruneAmount seconds
|
// vector clock
|
||||||
Prune(pruneAmount int) error
|
Prune() error
|
||||||
// GetPeers: get a list of contactable peers
|
// GetPeers: get a list of contactable peers
|
||||||
GetPeers() []string
|
GetPeers() []string
|
||||||
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
|
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
|
||||||
GetRoutes(targetNode string) (map[string]Route, error)
|
GetRoutes(targetNode string) (map[string]Route, error)
|
||||||
// RemoveNode(): remove the node from the mesh
|
// RemoveNode(): remove the node from the mesh
|
||||||
RemoveNode(nodeId string) error
|
RemoveNode(nodeId string) error
|
||||||
|
// Mark: marks the node as unreachable. This is not broadcast to the entire
|
||||||
|
// this is not considered when syncing node state
|
||||||
|
Mark(nodeId string)
|
||||||
|
// GetConfiguration: gets the configuration parameters specific for this
|
||||||
|
// mesh network
|
||||||
|
GetConfiguration() *conf.WgConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostParameters contains the IDs of a node
|
// HostParameters contains the IDs of a node
|
||||||
@ -154,12 +162,13 @@ func (h *HostParameters) GetPublicKey() string {
|
|||||||
|
|
||||||
// MeshProviderFactoryParams parameters required to build a mesh provider
|
// MeshProviderFactoryParams parameters required to build a mesh provider
|
||||||
type MeshProviderFactoryParams struct {
|
type MeshProviderFactoryParams struct {
|
||||||
DevName string
|
DevName string
|
||||||
MeshId string
|
MeshId string
|
||||||
Port int
|
Port int
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.WgConfiguration
|
||||||
Client *wgctrl.Client
|
DaemonConf *conf.DaemonConfiguration
|
||||||
NodeID string
|
Client *wgctrl.Client
|
||||||
|
NodeID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// MeshProviderFactory creates an instance of a mesh provider
|
// MeshProviderFactory creates an instance of a mesh provider
|
||||||
@ -170,10 +179,11 @@ type MeshProviderFactory interface {
|
|||||||
// MeshNodeFactoryParams are the parameters required to construct
|
// MeshNodeFactoryParams are the parameters required to construct
|
||||||
// a mesh node
|
// a mesh node
|
||||||
type MeshNodeFactoryParams struct {
|
type MeshNodeFactoryParams struct {
|
||||||
PublicKey *wgtypes.Key
|
PublicKey *wgtypes.Key
|
||||||
NodeIP net.IP
|
NodeIP net.IP
|
||||||
WgPort int
|
WgPort int
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
MeshConfig *conf.WgConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
// MeshBuilder build the hosts mesh node for it to be added to the mesh
|
// MeshBuilder build the hosts mesh node for it to be added to the mesh
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
"github.com/tim-beatham/wgmesh/pkg/ipc"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
@ -21,7 +22,25 @@ type IpcHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
||||||
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
|
overrideConf := &conf.WgConfiguration{}
|
||||||
|
|
||||||
|
if args.Role != "" {
|
||||||
|
role := conf.NodeType(args.Role)
|
||||||
|
overrideConf.Role = &role
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.Endpoint != "" {
|
||||||
|
overrideConf.Endpoint = &args.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
if *overrideConf.Role == conf.CLIENT_ROLE {
|
||||||
|
return fmt.Errorf("cannot create a mesh with no public endpoint")
|
||||||
|
}
|
||||||
|
|
||||||
|
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
|
||||||
|
Port: args.WgPort,
|
||||||
|
Conf: overrideConf,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -45,7 +64,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
|
|||||||
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
|
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
for meshId, _ := range n.Server.GetMeshManager().GetMeshes() {
|
for meshId := range n.Server.GetMeshManager().GetMeshes() {
|
||||||
meshNames[i] = meshId
|
meshNames[i] = meshId
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
@ -55,6 +74,17 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
||||||
|
overrideConf := &conf.WgConfiguration{}
|
||||||
|
|
||||||
|
if args.Role != "" {
|
||||||
|
role := conf.NodeType(args.Role)
|
||||||
|
overrideConf.Role = &role
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.Endpoint != "" {
|
||||||
|
overrideConf.Endpoint = &args.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
|
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -88,6 +118,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
|||||||
MeshId: args.MeshId,
|
MeshId: args.MeshId,
|
||||||
WgPort: args.Port,
|
WgPort: args.Port,
|
||||||
MeshBytes: meshReply.Mesh,
|
MeshBytes: meshReply.Mesh,
|
||||||
|
Conf: overrideConf,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package sync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Syncer: picks random nodes from the mesh
|
// Syncer: picks random nodes from the meshs
|
||||||
type Syncer interface {
|
type Syncer interface {
|
||||||
Sync(meshId string) error
|
Sync(meshId string) error
|
||||||
SyncMeshes() error
|
SyncMeshes() error
|
||||||
@ -24,72 +24,101 @@ type SyncerImpl struct {
|
|||||||
infectionCount int
|
infectionCount int
|
||||||
syncCount int
|
syncCount int
|
||||||
cluster conn.ConnCluster
|
cluster conn.ConnCluster
|
||||||
conf *conf.WgMeshConfiguration
|
conf *conf.DaemonConfiguration
|
||||||
|
lastSync uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync: Sync random nodes
|
// Sync: Sync random nodes
|
||||||
func (s *SyncerImpl) Sync(meshId string) error {
|
func (s *SyncerImpl) Sync(meshId string) error {
|
||||||
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
// Self can be nil if the node is removed
|
||||||
|
self, _ := s.manager.GetSelf(meshId)
|
||||||
|
|
||||||
|
correspondingMesh := s.manager.GetMesh(meshId)
|
||||||
|
|
||||||
|
correspondingMesh.Prune()
|
||||||
|
|
||||||
|
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
|
||||||
logging.Log.WriteInfof("No changes for %s", meshId)
|
logging.Log.WriteInfof("No changes for %s", meshId)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof("UPDATING WG CONF")
|
before := time.Now()
|
||||||
|
|
||||||
s.manager.GetRouteManager().UpdateRoutes()
|
s.manager.GetRouteManager().UpdateRoutes()
|
||||||
err := s.manager.ApplyConfig()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logging.Log.WriteInfof("Failed to update config %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKey := s.manager.GetPublicKey()
|
publicKey := s.manager.GetPublicKey()
|
||||||
|
|
||||||
logging.Log.WriteInfof(publicKey.String())
|
logging.Log.WriteInfof(publicKey.String())
|
||||||
|
|
||||||
nodeNames := s.manager.GetMesh(meshId).GetPeers()
|
nodeNames := correspondingMesh.GetPeers()
|
||||||
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
|
||||||
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
|
||||||
|
|
||||||
for _, node := range randomSubset {
|
if self != nil {
|
||||||
logging.Log.WriteInfof("Random node: %s", node)
|
nodeNames = lib.Filter(nodeNames, func(s string) bool {
|
||||||
|
return s != mesh.NodeID(self)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
before := time.Now()
|
var gossipNodes []string
|
||||||
|
|
||||||
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
|
// Clients always pings its peer for configuration
|
||||||
logging.Log.WriteInfof("Sending to random cluster")
|
if self != nil && self.GetType() == conf.CLIENT_ROLE {
|
||||||
interCluster := s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
keyFunc := lib.HashString
|
||||||
randomSubset = append(randomSubset, interCluster)
|
bucketFunc := lib.HashString
|
||||||
|
|
||||||
|
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc)
|
||||||
|
gossipNodes = make([]string, 1)
|
||||||
|
gossipNodes[0] = neighbour
|
||||||
|
} else {
|
||||||
|
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
|
||||||
|
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
|
||||||
|
|
||||||
|
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
|
||||||
|
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var waitGroup sync.WaitGroup
|
var succeeded bool = false
|
||||||
|
|
||||||
for index := range randomSubset {
|
// Do this synchronously to conserve bandwidth
|
||||||
waitGroup.Add(1)
|
for _, node := range gossipNodes {
|
||||||
|
correspondingPeer := s.manager.GetNode(meshId, node)
|
||||||
|
|
||||||
go func(i int) error {
|
if correspondingPeer == nil {
|
||||||
defer waitGroup.Done()
|
logging.Log.WriteErrorf("node %s does not exist", node)
|
||||||
|
}
|
||||||
|
|
||||||
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
|
err := s.requester.SyncMesh(meshId, correspondingPeer)
|
||||||
|
|
||||||
if correspondingPeer == nil {
|
if err == nil || err == io.EOF {
|
||||||
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
|
succeeded = true
|
||||||
}
|
} else {
|
||||||
|
// If the synchronisation operation has failed them mark a gravestone
|
||||||
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
|
// preventing the peer from being re-contacted until it has updated
|
||||||
return err
|
// itself
|
||||||
}(index)
|
s.manager.GetMesh(meshId).Mark(node)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
waitGroup.Wait()
|
|
||||||
|
|
||||||
s.syncCount++
|
s.syncCount++
|
||||||
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
|
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
|
||||||
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
|
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
|
||||||
|
|
||||||
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
|
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
|
||||||
|
|
||||||
|
if !succeeded {
|
||||||
|
// If could not gossip with anyone then repeat.
|
||||||
|
s.infectionCount++
|
||||||
|
}
|
||||||
|
|
||||||
s.manager.GetMesh(meshId).SaveChanges()
|
s.manager.GetMesh(meshId).SaveChanges()
|
||||||
|
s.lastSync = uint64(time.Now().Unix())
|
||||||
|
|
||||||
|
logging.Log.WriteInfof("UPDATING WG CONF")
|
||||||
|
err := s.manager.ApplyConfig()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteInfof("Failed to update config %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,14 +128,14 @@ func (s *SyncerImpl) SyncMeshes() error {
|
|||||||
err := s.Sync(meshId)
|
err := s.Sync(meshId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer {
|
func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer {
|
||||||
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
|
cluster, _ := conn.NewConnCluster(conf.ClusterSize)
|
||||||
return &SyncerImpl{
|
return &SyncerImpl{
|
||||||
manager: m,
|
manager: m,
|
||||||
|
@ -17,31 +17,20 @@ type SyncErrorHandlerImpl struct {
|
|||||||
meshManager mesh.MeshManager
|
meshManager mesh.MeshManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
|
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
|
||||||
mesh := s.meshManager.GetMesh(meshId)
|
mesh := s.meshManager.GetMesh(meshId)
|
||||||
|
mesh.Mark(nodeId)
|
||||||
if mesh == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// self, err := s.meshManager.GetSelf(meshId)
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
|
|
||||||
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
|
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool {
|
||||||
errStatus, _ := status.FromError(err)
|
errStatus, _ := status.FromError(err)
|
||||||
|
|
||||||
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
|
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
|
||||||
|
|
||||||
switch errStatus.Code() {
|
switch errStatus.Code() {
|
||||||
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
||||||
return s.incrementFailedCount(meshId, endpoint)
|
return s.handleFailed(meshId, nodeId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
// SyncRequester: coordinates the syncing of meshes
|
// SyncRequester: coordinates the syncing of meshes
|
||||||
type SyncRequester interface {
|
type SyncRequester interface {
|
||||||
GetMesh(meshId string, ifName string, port int, endPoint string) error
|
GetMesh(meshId string, ifName string, port int, endPoint string) error
|
||||||
SyncMesh(meshid string, endPoint string) error
|
SyncMesh(meshid string, meshNode mesh.MeshNode) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncRequesterImpl struct {
|
type SyncRequesterImpl struct {
|
||||||
@ -56,8 +56,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error {
|
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
|
||||||
ok := s.errorHdlr.Handle(meshId, endpoint, err)
|
ok := s.errorHdlr.Handle(meshId, pubKey, err)
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
return nil
|
return nil
|
||||||
@ -67,7 +67,10 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SyncMesh: Proactively send a sync request to the other mesh
|
// SyncMesh: Proactively send a sync request to the other mesh
|
||||||
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error {
|
||||||
|
endpoint := meshNode.GetHostEndpoint()
|
||||||
|
pubKey, _ := meshNode.GetPublicKey()
|
||||||
|
|
||||||
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
|
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -88,7 +91,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
|
|
||||||
c := rpc.NewSyncServiceClient(client)
|
c := rpc.NewSyncServiceClient(client)
|
||||||
|
|
||||||
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second)
|
syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@ -96,7 +99,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
err = s.syncMesh(mesh, ctx, c)
|
err = s.syncMesh(mesh, ctx, c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.handleErr(meshId, endpoint, err)
|
return s.handleErr(meshId, pubKey.String(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
||||||
|
@ -8,10 +8,11 @@ import (
|
|||||||
// Run implements SyncScheduler.
|
// Run implements SyncScheduler.
|
||||||
func syncFunction(syncer Syncer) lib.TimerFunc {
|
func syncFunction(syncer Syncer) lib.TimerFunc {
|
||||||
return func() error {
|
return func() error {
|
||||||
return syncer.SyncMeshes()
|
syncer.SyncMeshes()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
|
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
|
||||||
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
|
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate)
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,5 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
|
|||||||
logging.Log.WriteInfof("Updated Timestamp")
|
logging.Log.WriteInfof("Updated Timestamp")
|
||||||
return ctrlServer.MeshManager.UpdateTimeStamp()
|
return ctrlServer.MeshManager.UpdateTimeStamp()
|
||||||
}
|
}
|
||||||
|
|
||||||
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
|
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package wg
|
|||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
@ -35,8 +34,7 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.
|
|||||||
}
|
}
|
||||||
|
|
||||||
md5 := crypto.MD5.New().Sum(randomBuf)
|
md5 := crypto.MD5.New().Sum(randomBuf)
|
||||||
|
md5Str := fmt.Sprintf("wg%x", md5)[:hashLength]
|
||||||
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
|
|
||||||
|
|
||||||
err = rtnl.CreateLink(md5Str)
|
err = rtnl.CreateLink(md5Str)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user