From fe14f632179735a763027f7b9965e84f816a044c Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Sun, 10 Dec 2023 19:21:54 +0000 Subject: [PATCH] 53-run-commands-pre-up-and-post-down - Ability to run a command pre up and post down - Ability to be a client in one mesh and a peer in the other - Added dev card to specify different sync rate, keepalive rate per mesh. --- cmd/wg-mesh/main.go | 40 +---- cmd/wgmeshd/main.go | 6 +- pkg/automerge/automerge.go | 67 ++------ pkg/automerge/automerge_test.go | 2 +- pkg/automerge/factory.go | 14 +- pkg/cmd/cmd.go | 33 ++++ pkg/conf/conf.go | 281 +++++++++++++++++--------------- pkg/conf/conf_test.go | 117 +------------ pkg/conn/connectionserver.go | 9 +- pkg/crdt/datastore.go | 20 ++- pkg/crdt/factory.go | 33 ++-- pkg/ctrlserver/ctrlserver.go | 10 +- pkg/ctrlserver/ctrltypes.go | 4 +- pkg/ctrlserver/stub.go | 8 +- pkg/ipc/ipc.go | 5 +- pkg/mesh/config.go | 70 +++++--- pkg/mesh/manager.go | 102 ++++++++---- pkg/mesh/manager_test.go | 13 +- pkg/mesh/pruner.go | 16 -- pkg/mesh/route.go | 6 +- pkg/mesh/stub_types.go | 9 +- pkg/mesh/types.go | 25 +-- pkg/robin/requester.go | 35 +++- pkg/sync/syncer.go | 10 +- pkg/sync/syncrequester.go | 2 +- pkg/sync/syncscheduler.go | 2 +- pkg/timers/timers.go | 1 - 27 files changed, 466 insertions(+), 474 deletions(-) create mode 100644 pkg/cmd/cmd.go delete mode 100644 pkg/mesh/pruner.go diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 7462ceb..526beae 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -4,13 +4,9 @@ import ( "fmt" ipcRpc "net/rpc" "os" - "strings" - "time" "github.com/akamensky/argparse" - "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" - "github.com/tim-beatham/wgmesh/pkg/lib" logging "github.com/tim-beatham/wgmesh/pkg/log" ) @@ -20,6 +16,7 @@ type CreateMeshParams struct { Client *ipcRpc.Client WgPort int Endpoint string + Role string } func createMesh(args *CreateMeshParams) string { @@ -27,6 +24,7 @@ func createMesh(args *CreateMeshParams) string { newMeshParams := ipc.NewMeshArgs{ WgPort: args.WgPort, Endpoint: args.Endpoint, + Role: args.Role, } err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) @@ -60,6 +58,7 @@ type JoinMeshParams struct { IfName string WgPort int Endpoint string + Role string } func joinMesh(params *JoinMeshParams) string { @@ -69,6 +68,7 @@ func joinMesh(params *JoinMeshParams) string { MeshId: params.MeshId, IpAdress: params.IpAddress, Port: params.WgPort, + Role: params.Role, } err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply) @@ -80,34 +80,6 @@ func joinMesh(params *JoinMeshParams) string { return reply } -func getMesh(client *ipcRpc.Client, meshId string) { - reply := new(ipc.GetMeshReply) - - err := client.Call("IpcHandler.GetMesh", &meshId, &reply) - - if err != nil { - fmt.Println(err.Error()) - return - } - - for _, node := range reply.Nodes { - fmt.Println("Public Key: " + node.PublicKey) - fmt.Println("Control Endpoint: " + node.HostEndpoint) - fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) - fmt.Println("Wg IP: " + node.WgHost) - fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()) - - mapFunc := func(r ctrlserver.MeshRoute) string { - return r.Destination - } - - advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",") - fmt.Printf("Routes: %s\n", advertiseRoutes) - - fmt.Println("---") - } -} - func leaveMesh(client *ipcRpc.Client, meshId string) { var reply string @@ -255,11 +227,13 @@ func main() { var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) + var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) + var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) @@ -300,6 +274,7 @@ func main() { Client: client, WgPort: *newMeshPort, Endpoint: *newMeshEndpoint, + Role: *newMeshRole, })) } @@ -314,6 +289,7 @@ func main() { IpAddress: *joinMeshIpAddress, MeshId: *joinMeshId, Endpoint: *joinMeshEndpoint, + Role: *joinMeshRole, })) } diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index d495830..95aa9b3 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -19,13 +19,13 @@ import ( func main() { if len(os.Args) != 2 { - logging.Log.WriteErrorf("Need to provide configuration.yaml") + logging.Log.WriteErrorf("Did not provide configuration") return } - conf, err := conf.ParseConfiguration(os.Args[1]) + conf, err := conf.ParseDaemonConfiguration(os.Args[1]) if err != nil { - logging.Log.WriteInfof("Could not parse configuration") + logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error()) return } diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 1e5b580..b62db73 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -24,7 +24,7 @@ type CrdtMeshManager struct { Client *wgctrl.Client doc *automerge.Doc LastHash automerge.ChangeHash - conf *conf.WgMeshConfiguration + conf *conf.WgConfiguration cache *MeshCrdt lastCacheHash automerge.ChangeHash } @@ -74,8 +74,8 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool { return false } - keepAliveTime := timestamp.Int64() - return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime) + return true + // return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime) } func (c *CrdtMeshManager) GetPeers() []string { @@ -135,7 +135,7 @@ type NewCrdtNodeMangerParams struct { MeshId string DevName string Port int - Conf conf.WgMeshConfiguration + Conf *conf.WgConfiguration Client *wgctrl.Client } @@ -146,7 +146,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro manager.doc = automerge.New() manager.IfName = params.DevName manager.Client = params.Client - manager.conf = ¶ms.Conf + manager.conf = params.Conf manager.cache = nil return &manager, nil } @@ -473,59 +473,20 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) erro return err } +// GetConfiguration: gets the configuration for this mesh network +func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration { + return m.conf +} + +// Mark: mark the node as locally dead +func (m *CrdtMeshManager) Mark(nodeId string) { +} + func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { return NewAutomergeSync(m) } func (m *CrdtMeshManager) Prune() error { - // nodes, err := m.doc.Path("nodes").Get() - - // if err != nil { - // return err - // } - - // if nodes.Kind() != automerge.KindMap { - // return errors.New("node must be a map") - // } - - // values, err := nodes.Map().Values() - - // if err != nil { - // return err - // } - - // deletionNodes := make([]string, 0) - - // for nodeId, node := range values { - // if node.Kind() != automerge.KindMap { - // return errors.New("node must be a map") - // } - - // nodeMap := node.Map() - - // timeStamp, err := nodeMap.Get("timestamp") - - // if err != nil { - // return err - // } - - // if timeStamp.Kind() != automerge.KindInt64 { - // return errors.New("timestamp is not int64") - // } - - // timeValue := timeStamp.Int64() - // nowValue := time.Now().Unix() - - // if nowValue-timeValue >= int64(pruneTime) { - // deletionNodes = append(deletionNodes, nodeId) - // } - // } - - // for _, node := range deletionNodes { - // logging.Log.WriteInfof("Pruning %s", node) - // nodes.Map().Delete(node) - // } - return nil } diff --git a/pkg/automerge/automerge_test.go b/pkg/automerge/automerge_test.go index 396b506..4c74e93 100644 --- a/pkg/automerge/automerge_test.go +++ b/pkg/automerge/automerge_test.go @@ -22,7 +22,7 @@ func setUpTests() *TestParams { DevName: "wg0", Port: 5000, Client: nil, - Conf: conf.WgMeshConfiguration{}, + Conf: conf.DaemonConfiguration{}, }) return &TestParams{ diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index 69f9306..2778e57 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -14,13 +14,13 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) return NewCrdtNodeManager(&NewCrdtNodeMangerParams{ MeshId: params.MeshId, DevName: params.DevName, - Conf: *params.Conf, + Conf: params.Conf, Client: params.Client, }) } type MeshNodeFactory struct { - Config conf.WgMeshConfiguration + Config conf.DaemonConfiguration } // Build builds the mesh node that represents the host machine to add @@ -30,7 +30,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) - if f.Config.Role == conf.CLIENT_ROLE { + if *params.MeshConfig.Role == conf.CLIENT_ROLE { grpcEndpoint = "-" } @@ -44,7 +44,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod Routes: make(map[string]Route), Description: "", Alias: "", - Type: string(f.Config.Role), + Type: string(*params.MeshConfig.Role), } } @@ -54,12 +54,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string if params.Endpoint != "" { hostName = params.Endpoint - } else if len(f.Config.Endpoint) != 0 { - hostName = f.Config.Endpoint + } else if len(*params.MeshConfig.Endpoint) != 0 { + hostName = *params.MeshConfig.Endpoint } else { ipFunc := lib.GetPublicIP - if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY { + if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { ipFunc = lib.GetOutboundIP } diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go new file mode 100644 index 0000000..1237b65 --- /dev/null +++ b/pkg/cmd/cmd.go @@ -0,0 +1,33 @@ +// cmd is a package for running commands in the different operating systems implementations +package cmd + +import ( + "os/exec" + "strings" +) + +type CmdRunner interface { + RunCommands(commands ...string) error +} + +type UnixCmdRunner struct{} + +// RunCommand: runs the unix command. It splits the command into fields +// and then runs the command accordingly +func RunCommand(cmd string) error { + args := strings.Fields(cmd) + c := exec.Command(args[0], args[1:]...) + return c.Run() +} + +func (l *UnixCmdRunner) RunCommands(commands ...string) error { + for _, cmd := range commands { + err := RunCommand(cmd) + + if err != nil { + return err + } + } + + return nil +} diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 9063485..cca0a35 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -4,7 +4,7 @@ package conf import ( "os" - logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/go-playground/validator/v10" "gopkg.in/yaml.v3" ) @@ -30,172 +30,187 @@ const ( DNS_IP_DISCOVERY = "dns" ) -type WgMeshConfiguration struct { +// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can +// tell if the attribute is set +type WgConfiguration struct { + // IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public + // service for IPDiscoverability + IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"` + // AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes + AdvertiseRoutes *bool `yaml:"advertiseRoutes"` + // AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route + // for all nodes to route their packets to + AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults"` + // Endpoint contains what value should be set as the public endpoint of this node + Endpoint *string `yaml:"publicEndpoint"` + // Role specifies whether or not the user is globally accessible. + // If the user is globaly accessible they specify themselves as a client. + Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"` + // KeepAliveWg configures the implementation so that we send keep alive packets to peers. + // KeepAlive can only be set if role is type client + KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"` + // PreUp are WireGuard commands to run before adding the WG interface + PreUp []string `yaml:"preUp"` + // PostUp are WireGuard commands to run after adding the WG interface + PostUp []string `yaml:"postUp"` + // PreDown are WireGuard commands to run prior to removing the WG interface + PreDown []string `yaml:"preDown"` + // PostDown are WireGuard command to run after removing the WG interface + PostDown []string `yaml:"postDown"` +} + +type DaemonConfiguration struct { // CertificatePath is the path to the certificate to use in mTLS - CertificatePath string `yaml:"certificatePath"` + CertificatePath string `yaml:"certificatePath" validate:"required,file"` // PrivateKeypath is the path to the clients private key in mTLS - PrivateKeyPath string `yaml:"privateKeyPath"` + PrivateKeyPath string `yaml:"privateKeyPath" validate:"required,file"` // CaCeritifcatePath path to the certificate of the trust certificate authority - CaCertificatePath string `yaml:"caCertificatePath"` + CaCertificatePath string `yaml:"caCertificatePath" validate:"required,file"` // SkipCertVerification specify to skip certificate verification. Should only be used // in test environments SkipCertVerification bool `yaml:"skipCertVerification"` // Port to run the GrpcServer on - GrpcPort string `yaml:"gRPCPort"` - // IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or - // use public IP discovery library - IPDiscovery IPDiscovery `yaml:"ipDiscovery"` - // AdvertiseRoutes advertises other meshes if the node is in multiple meshes - AdvertiseRoutes bool `yaml:"advertiseRoutes"` - // AdvertiseDefaultRoute advertises a default route out of the mesh. - AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"` - // Endpoint is the IP in which this computer is publicly reachable. - // usecase is when the node has multiple IP addresses - Endpoint string `yaml:"publicEndpoint"` - // ClusterSize size of the cluster to split on - ClusterSize int `yaml:"clusterSize"` - // SyncRate number of times per second to perform a sync - SyncRate float64 `yaml:"syncRate"` - // InterClusterChance proability of inter-cluster communication in a sync round - InterClusterChance float64 `yaml:"interClusterChance"` - // BranchRate number of nodes to randomly communicate with - BranchRate int `yaml:"branchRate"` - // InfectionCount number of times we sync before we can no longer catch the udpate - InfectionCount int `yaml:"infectionCount"` - // KeepAliveTime number of seconds before we update node indicating that we are still alive - KeepAliveTime int `yaml:"keepAliveTime"` - // Timeout number of seconds before we consider the node as dead - Timeout int `yaml:"timeout"` - // PruneTime number of seconds before we remove nodes that are likely to be dead - PruneTime int `yaml:"pruneTime"` - // DeadTime: number of seconds before we consider the node as dead and stop considering it - // when picking a random peer - DeadTime int `yaml:"deadTime"` + GrpcPort int `yaml:"gRPCPort" validate:"required"` + // Timeout number of seconds without response that a node is considered unreachable by gRPC + Timeout int `yaml:"timeout" validate:"required,gte=1"` // Profile whether or not to include a http server that profiles the code Profile bool `yaml:"profile"` // StubWg whether or not to stub the WireGuard types StubWg bool `yaml:"stubWg"` - // Role specifies whether or not the user is globally accessible. - // If the user is globaly accessible they specify themselves as a client. - Role NodeType `yaml:"role"` - // KeepAliveWg configures the implementation so that we send keep alive packets to peers. - // KeepAlive can only be set if role is type client - KeepAliveWg int `yaml:"keepAliveWg"` + // SyncRate specifies how long the minimum time should be between synchronisation + SyncRate int `yaml:"syncRate" validate:"required,gte=1"` + // KeepAliveTime: number of seconds before the leader of the mesh sends an update to + // send to every member in the mesh + KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"` + // ClusterSize specifies how many neighbours you should synchronise with per round + ClusterSize int `yaml:"clusterSize" valdiate:"required,gt=0"` + // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round + InterClusterChance float64 `yaml:"interClusterChance" valdiate:"required,gt=0"` + // BranchRate specifies the number of nodes to synchronise with when a node has + // new changes to send to the mesh + BranchRate int `yaml:"branchRate" validate:"required,gte=1"` + // InfectionCount: number of time to sync before an update can no longer be 'caught' + InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"` + // BaseConfiguration base WireGuard configuration to use, this is used when none is provided + BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"` } -func ValidateConfiguration(c *WgMeshConfiguration) error { - if len(c.CertificatePath) == 0 { - return &WgMeshConfigurationError{ - msg: "A public certificate must be specified for mTLS", - } +// ValdiateMeshConfiguration: validates the mesh configuration +func ValidateMeshConfiguration(conf *WgConfiguration) error { + validate := validator.New(validator.WithRequiredStructEnabled()) + err := validate.Struct(conf) + + if conf.PostDown == nil { + conf.PostDown = make([]string, 0) } - if len(c.PrivateKeyPath) == 0 { - return &WgMeshConfigurationError{ - msg: "A private key must be specified for mTLS", - } + if conf.PostUp == nil { + conf.PostUp = make([]string, 0) } - if len(c.CaCertificatePath) == 0 { - return &WgMeshConfigurationError{ - msg: "A ca certificate must be specified for mTLS", - } + if conf.PreDown == nil { + conf.PreDown = make([]string, 0) } - if len(c.GrpcPort) == 0 { - return &WgMeshConfigurationError{ - msg: "A grpc port must be specified", - } + if conf.PreUp == nil { + conf.PreUp = make([]string, 0) } - if c.ClusterSize <= 0 { - return &WgMeshConfigurationError{ - msg: "A cluster size must not be 0", - } - } - - if c.SyncRate <= 0 { - return &WgMeshConfigurationError{ - msg: "SyncRate cannot be negative", - } - } - - if c.BranchRate <= 0 { - return &WgMeshConfigurationError{ - msg: "Branch rate cannot be negative", - } - } - - if c.InfectionCount <= 0 { - return &WgMeshConfigurationError{ - msg: "Infection count cannot be less than 1", - } - } - - if c.KeepAliveTime <= 0 { - return &WgMeshConfigurationError{ - msg: "KeepAliveRate cannot be less than negative", - } - } - - if c.InterClusterChance <= 0 { - return &WgMeshConfigurationError{ - msg: "Intercluster chance cannot be less than 0", - } - } - - if c.Timeout < 1 { - return &WgMeshConfigurationError{ - msg: "Timeout should be greater than or equal to 1", - } - } - - if c.PruneTime < 1 { - return &WgMeshConfigurationError{ - msg: "Prune time cannot be < 1", - } - } - - if c.DeadTime < 1 { - return &WgMeshConfigurationError{ - msg: "Dead time cannot be < 1", - } - } - - if c.KeepAliveTime <= 1 { - return &WgMeshConfigurationError{ - msg: "Prune time cannot be less than keep alive time", - } - } - - if c.Role == "" { - c.Role = PEER_ROLE - } - - if c.IPDiscovery == "" { - c.IPDiscovery = PUBLIC_IP_DISCOVERY - } - - return nil + return err } -// ParseConfiguration parses the mesh configuration -func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { - var conf WgMeshConfiguration +// ValidateDaemonConfiguration: validates the dameon configuration that is used. +func ValidateDaemonConfiguration(c *DaemonConfiguration) error { + validate := validator.New(validator.WithRequiredStructEnabled()) + err := validate.Struct(c) + return err +} + +// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as +// keepalive time, role and so forth. +func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) { + var conf WgConfiguration yamlBytes, err := os.ReadFile(filePath) if err != nil { - logging.Log.WriteErrorf("Read file error: %s\n", err.Error()) return nil, err } err = yaml.Unmarshal(yamlBytes, &conf) if err != nil { - logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error()) return nil, err } - return &conf, ValidateConfiguration(&conf) + return &conf, ValidateMeshConfiguration(&conf) +} + +// ParseDaemonConfiguration parses the mesh configuration and validates the configuration +func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) { + var conf DaemonConfiguration + + yamlBytes, err := os.ReadFile(filePath) + + if err != nil { + return nil, err + } + + err = yaml.Unmarshal(yamlBytes, &conf) + + if err != nil { + return nil, err + } + + return &conf, ValidateDaemonConfiguration(&conf) +} + +// MergemeshConfiguration: merges the configuration in precedence where the last +// element in the list takes the most and the first takes the least +func MergeMeshConfiguration(cfgs ...WgConfiguration) (WgConfiguration, error) { + var result WgConfiguration + + for _, cfg := range cfgs { + if cfg.AdvertiseDefaultRoute != nil { + result.AdvertiseDefaultRoute = cfg.AdvertiseDefaultRoute + } + + if cfg.AdvertiseRoutes != nil { + result.AdvertiseRoutes = cfg.AdvertiseRoutes + } + + if cfg.Endpoint != nil { + result.Endpoint = cfg.Endpoint + } + + if cfg.IPDiscovery != nil { + result.IPDiscovery = cfg.IPDiscovery + } + + if cfg.KeepAliveWg != nil { + result.KeepAliveWg = cfg.KeepAliveWg + } + + if cfg.PostDown != nil { + result.PostDown = cfg.PostDown + } + + if cfg.PostUp != nil { + result.PostUp = cfg.PostUp + } + + if cfg.PreDown != nil { + result.PreDown = cfg.PreDown + } + + if cfg.PreUp != nil { + result.PreUp = cfg.PreUp + } + + if cfg.Role != nil { + result.Role = cfg.Role + } + } + + return result, ValidateMeshConfiguration(&result) } diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index f6436bf..6facf82 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -2,23 +2,12 @@ package conf import "testing" -func getExampleConfiguration() *WgMeshConfiguration { - return &WgMeshConfiguration{ +func getExampleConfiguration() *DaemonConfiguration { + return &DaemonConfiguration{ CertificatePath: "./cert/cert.pem", PrivateKeyPath: "./cert/key.pem", CaCertificatePath: "./cert/ca.pems", SkipCertVerification: true, - GrpcPort: "8080", - AdvertiseRoutes: true, - Endpoint: "localhost", - ClusterSize: 1, - SyncRate: 1, - InterClusterChance: 0.1, - BranchRate: 2, - KeepAliveTime: 4, - InfectionCount: 1, - Timeout: 2, - PruneTime: 20, } } @@ -26,7 +15,7 @@ func TestConfigurationCertificatePathEmpty(t *testing.T) { conf := getExampleConfiguration() conf.CertificatePath = "" - err := ValidateConfiguration(conf) + err := ValidateDaemonConfiguration(conf) if err == nil { t.Fatal(`error should be thrown`) @@ -37,7 +26,7 @@ func TestConfigurationPrivateKeyPathEmpty(t *testing.T) { conf := getExampleConfiguration() conf.PrivateKeyPath = "" - err := ValidateConfiguration(conf) + err := ValidateDaemonConfiguration(conf) if err == nil { t.Fatal(`error should be thrown`) @@ -48,7 +37,7 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) { conf := getExampleConfiguration() conf.CaCertificatePath = "" - err := ValidateConfiguration(conf) + err := ValidateDaemonConfiguration(conf) if err == nil { t.Fatal(`error should be thrown`) @@ -57,109 +46,21 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) { func TestConfigurationGrpcPortEmpty(t *testing.T) { conf := getExampleConfiguration() - conf.GrpcPort = "" + conf.GrpcPort = 0 - err := ValidateConfiguration(conf) + err := ValidateDaemonConfiguration(conf) if err == nil { t.Fatal(`error should be thrown`) } } -func TestClusterSizeZero(t *testing.T) { - conf := getExampleConfiguration() - conf.ClusterSize = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatal(`error should be thrown`) - } -} - -func SyncRateZero(t *testing.T) { - conf := getExampleConfiguration() - conf.SyncRate = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatal(`error should be thrown`) - } -} - -func BranchRateZero(t *testing.T) { - conf := getExampleConfiguration() - conf.BranchRate = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatal(`error should be thrown`) - } -} - -func InfectionCountZero(t *testing.T) { - conf := getExampleConfiguration() - conf.InfectionCount = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatal(`error should be thrown`) - } -} - -func KeepAliveRateZero(t *testing.T) { - conf := getExampleConfiguration() - conf.KeepAliveTime = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatal(`error should be thrown`) - } -} - -func TestValidCOnfiguration(t *testing.T) { +func TestValidConfiguration(t *testing.T) { conf := getExampleConfiguration() - err := ValidateConfiguration(conf) + err := ValidateDaemonConfiguration(conf) if err != nil { t.Error(err) } } - -func TestTimeout(t *testing.T) { - conf := getExampleConfiguration() - conf.Timeout = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatal(`error should be thrown`) - } -} - -func TestPruneTimeZero(t *testing.T) { - conf := getExampleConfiguration() - conf.PruneTime = 0 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatalf(`Error should be thrown`) - } -} - -func TestPruneTimeLessThanKeepAliveTime(t *testing.T) { - conf := getExampleConfiguration() - conf.PruneTime = 1 - - err := ValidateConfiguration(conf) - - if err == nil { - t.Fatalf(`Error should be thrown`) - } -} diff --git a/pkg/conn/connectionserver.go b/pkg/conn/connectionserver.go index b4d3808..f4e9e3f 100644 --- a/pkg/conn/connectionserver.go +++ b/pkg/conn/connectionserver.go @@ -2,6 +2,7 @@ package conn import ( "crypto/tls" + "fmt" "net" "github.com/tim-beatham/wgmesh/pkg/conf" @@ -21,13 +22,13 @@ type ConnectionServer struct { ctrlProvider rpc.MeshCtrlServerServer // the sync service to synchronise nodes syncProvider rpc.SyncServiceServer - Conf *conf.WgMeshConfiguration + Conf *conf.DaemonConfiguration listener net.Listener } // NewConnectionServerParams contains params for creating a new connection server type NewConnectionServerParams struct { - Conf *conf.WgMeshConfiguration + Conf *conf.DaemonConfiguration CtrlProvider rpc.MeshCtrlServerServer SyncProvider rpc.SyncServiceServer } @@ -76,10 +77,10 @@ func (s *ConnectionServer) Listen() error { rpc.RegisterSyncServiceServer(s.server, s.syncProvider) - lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort) + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort)) s.listener = lis - logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort) + logging.Log.WriteInfof("GRPC listening on %d\n", s.Conf.GrpcPort) if err != nil { logging.Log.WriteErrorf(err.Error()) diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go index bd77585..aae016c 100644 --- a/pkg/crdt/datastore.go +++ b/pkg/crdt/datastore.go @@ -154,12 +154,13 @@ func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode { } type TwoPhaseStoreMeshManager struct { - MeshId string - IfName string - Client *wgctrl.Client - LastClock uint64 - conf *conf.WgMeshConfiguration - store *TwoPhaseMap[string, MeshNode] + MeshId string + IfName string + Client *wgctrl.Client + LastClock uint64 + conf *conf.WgConfiguration + daemonConf *conf.DaemonConfiguration + store *TwoPhaseMap[string, MeshNode] } // AddNode() adds a node to the mesh @@ -264,7 +265,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { peerToUpdate := peers[0] - if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) { + if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) { m.store.Mark(peerToUpdate) if len(peers) < 2 { @@ -506,3 +507,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error { m.store.Remove(nodeId) return nil } + +// GetConfiguration implements mesh.MeshProvider. +func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration { + return m.conf +} diff --git a/pkg/crdt/factory.go b/pkg/crdt/factory.go index 4895bbf..7c53a8e 100644 --- a/pkg/crdt/factory.go +++ b/pkg/crdt/factory.go @@ -9,44 +9,49 @@ import ( "github.com/tim-beatham/wgmesh/pkg/mesh" ) -type TwoPhaseMapFactory struct{} +type TwoPhaseMapFactory struct { + Config *conf.DaemonConfiguration +} func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { return &TwoPhaseStoreMeshManager{ - MeshId: params.MeshId, - IfName: params.DevName, - Client: params.Client, - conf: params.Conf, + MeshId: params.MeshId, + IfName: params.DevName, + Client: params.Client, + conf: params.Conf, + daemonConf: params.DaemonConf, store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { h := fnv.New64a() h.Write([]byte(s)) return h.Sum64() - }, uint64(3*params.Conf.KeepAliveTime)), + }, uint64(3*f.Config.KeepAliveTime)), }, nil } type MeshNodeFactory struct { - Config conf.WgMeshConfiguration + Config conf.DaemonConfiguration } func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { hostName := f.getAddress(params) - grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) + grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort) + wgEndpoint := fmt.Sprintf("%s:%d", hostName, params.WgPort) - if f.Config.Role == conf.CLIENT_ROLE { + if *params.MeshConfig.Role == conf.CLIENT_ROLE { grpcEndpoint = "-" + wgEndpoint = "-" } return &MeshNode{ HostEndpoint: grpcEndpoint, PublicKey: params.PublicKey.String(), - WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), + WgEndpoint: wgEndpoint, WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), Routes: make(map[string]Route), Description: "", Alias: "", - Type: string(f.Config.Role), + Type: string(*params.MeshConfig.Role), } } @@ -56,12 +61,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string if params.Endpoint != "" { hostName = params.Endpoint - } else if len(f.Config.Endpoint) != 0 { - hostName = f.Config.Endpoint + } else if params.MeshConfig.Endpoint != nil && len(*params.MeshConfig.Endpoint) != 0 { + hostName = *params.MeshConfig.Endpoint } else { ipFunc := lib.GetPublicIP - if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY { + if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { ipFunc = lib.GetOutboundIP } diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 9e54b44..6bfd9dd 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -16,7 +16,7 @@ import ( // NewCtrlServerParams are the params requried to create a new ctrl server type NewCtrlServerParams struct { - Conf *conf.WgMeshConfiguration + Conf *conf.DaemonConfiguration Client *wgctrl.Client CtrlProvider rpc.MeshCtrlServerServer SyncProvider rpc.SyncServiceServer @@ -28,7 +28,9 @@ type NewCtrlServerParams struct { // operation failed func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer := new(MeshCtrlServer) - meshFactory := &crdt.TwoPhaseMapFactory{} + meshFactory := &crdt.TwoPhaseMapFactory{ + Config: params.Conf, + } nodeFactory := &crdt.MeshNodeFactory{ Config: *params.Conf, } @@ -36,7 +38,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ipAllocator := &ip.ULABuilder{} interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) - configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf) + configApplyer := mesh.NewWgMeshConfigApplyer() meshManagerParams := &mesh.NewMeshManagerParams{ Conf: *params.Conf, @@ -87,7 +89,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { return ctrlServer, nil } -func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration { +func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration { return s.Conf } diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index 0d03ca6..00f583d 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -34,7 +34,7 @@ type Mesh struct { } type CtrlServer interface { - GetConfiguration() *conf.WgMeshConfiguration + GetConfiguration() *conf.DaemonConfiguration GetClient() *wgctrl.Client GetQuerier() query.Querier GetMeshManager() mesh.MeshManager @@ -48,6 +48,6 @@ type MeshCtrlServer struct { MeshManager mesh.MeshManager ConnectionManager conn.ConnectionManager ConnectionServer *conn.ConnectionServer - Conf *conf.WgMeshConfiguration + Conf *conf.DaemonConfiguration Querier query.Querier } diff --git a/pkg/ctrlserver/stub.go b/pkg/ctrlserver/stub.go index c88851a..d61da6b 100644 --- a/pkg/ctrlserver/stub.go +++ b/pkg/ctrlserver/stub.go @@ -23,10 +23,10 @@ func NewCtrlServerStub() *CtrlServerStub { } } -func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration { - return &conf.WgMeshConfiguration{ - GrpcPort: "8080", - Endpoint: "abc.com", +func (c *CtrlServerStub) GetConfiguration() *conf.DaemonConfiguration { + return &conf.DaemonConfiguration{ + GrpcPort: 8080, + BaseConfiguration: conf.WgConfiguration{}, } } diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index b095a7e..00f2b52 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -16,6 +16,7 @@ type NewMeshArgs struct { // Endpoint is the routable alias of the machine. Can be an IP // or DNS entry Endpoint string + Role string } type JoinMeshArgs struct { @@ -25,12 +26,12 @@ type JoinMeshArgs struct { IpAdress string // Port is the WireGuard port to expose Port int - // Endpoint is the routable address of this machine. If not provided - // defaults to the default address + // Endpoint to use to override the default Endpoint string // Client specifies whether we should join as a client of the peer // we are connecting to Client bool + Role string } type PutServiceArgs struct { diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index bddfd3e..6a24b25 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -24,7 +24,6 @@ type MeshConfigApplyer interface { // WgMeshConfigApplyer applies WireGuard configuration type WgMeshConfigApplyer struct { meshManager MeshManager - config *conf.WgMeshConfiguration routeInstaller route.RouteInstaller hashFunc func(MeshNode) int } @@ -34,28 +33,33 @@ type routeNode struct { route Route } -func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode, - device *wgtypes.Device, - peerToClients map[string][]net.IPNet, - routes map[string][]routeNode) (*wgtypes.PeerConfig, error) { +type convertMeshNodeParams struct { + node MeshNode + self MeshNode + mesh MeshProvider + device *wgtypes.Device + peerToClients map[string][]net.IPNet + routes map[string][]routeNode +} - pubKey, err := node.GetPublicKey() +func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) { + pubKey, err := params.node.GetPublicKey() if err != nil { return nil, err } allowedips := make([]net.IPNet, 1) - allowedips[0] = *node.GetWgHost() + allowedips[0] = *params.node.GetWgHost() - clients, ok := peerToClients[pubKey.String()] + clients, ok := params.peerToClients[pubKey.String()] if ok { allowedips = append(allowedips, clients...) } - for _, route := range node.GetRoutes() { - bestRoutes := routes[route.GetDestination().String()] + for _, route := range params.node.GetRoutes() { + bestRoutes := params.routes[route.GetDestination().String()] var pickedRoute routeNode if len(bestRoutes) == 1 { @@ -66,7 +70,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode, } // Else there is more than one candidate so consistently hash - pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc) + pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc) } if pickedRoute.gateway == pubKey.String() { @@ -74,14 +78,20 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode, } } - keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second + config := params.mesh.GetConfiguration() - existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool { - pubKey, _ := node.GetPublicKey() + var keepAlive time.Duration = time.Duration(0) + + if config.KeepAliveWg != nil { + keepAlive = time.Duration(*config.KeepAliveWg) * time.Second + } + + existing := slices.IndexFunc(params.device.Peers, func(p wgtypes.Peer) bool { + pubKey, _ := params.node.GetPublicKey() return p.PublicKey.String() == pubKey.String() }) - endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) + endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint()) if err != nil { return nil, err @@ -89,7 +99,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode, // Don't override the existing IP in case it already exists if existing != -1 { - endpoint = device.Peers[existing].Endpoint + endpoint = params.device.Peers[existing].Endpoint } peerConfig := wgtypes.PeerConfig{ @@ -127,7 +137,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][] v6Default, _, _ := net.ParseCIDR("::/0") v4Default, _, _ := net.ParseCIDR("0.0.0.0/0") - if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute { + if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute { return true } @@ -230,7 +240,10 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes peer := m.getCorrespondingPeer(params.peers, self) pubKey, _ := peer.GetPublicKey() - keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second + + config := params.mesh.GetConfiguration() + + keepAlive := time.Duration(*config.KeepAliveWg) * time.Second endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint()) if err != nil { @@ -308,7 +321,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C peerToClients[pubKey.String()] = append(clients, *n.GetWgHost()) if NodeEquals(self, peer) { - cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes) + cfg, err := m.convertMeshNode(convertMeshNodeParams{ + node: n, + self: self, + mesh: params.mesh, + device: params.dev, + peerToClients: peerToClients, + routes: params.routes, + }) if err != nil { return nil, err @@ -325,7 +345,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C continue } - peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes) + peer, err := m.convertMeshNode(convertMeshNodeParams{ + node: n, + self: self, + mesh: params.mesh, + peerToClients: peerToClients, + routes: params.routes, + device: params.dev, + }) if err != nil { return nil, err @@ -468,9 +495,8 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) { m.meshManager = manager } -func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer { +func NewWgMeshConfigApplyer() MeshConfigApplyer { return &WgMeshConfigApplyer{ - config: config, routeInstaller: route.NewRouteInstaller(), hashFunc: func(mn MeshNode) int { pubKey, _ := mn.GetPublicKey() diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 8fe4d33..26d8a9d 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/tim-beatham/wgmesh/pkg/cmd" "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/lib" @@ -14,7 +15,7 @@ import ( ) type MeshManager interface { - CreateMesh(port int) (string, error) + CreateMesh(params *CreateMeshParams) (string, error) AddMesh(params *AddMeshParams) error HasChanges(meshid string) bool GetMesh(meshId string) MeshProvider @@ -30,7 +31,6 @@ type MeshManager interface { UpdateTimeStamp() error GetClient() *wgctrl.Client GetMeshes() map[string]MeshProvider - Prune() error Close() error GetMonitor() MeshMonitor GetNode(string, string) MeshNode @@ -45,7 +45,7 @@ type MeshManagerImpl struct { // HostParameters contains information that uniquely locates // the node in the mesh network. HostParameters *HostParameters - conf *conf.WgMeshConfiguration + conf *conf.DaemonConfiguration meshProviderFactory MeshProviderFactory nodeFactory MeshNodeFactory configApplyer MeshConfigApplyer @@ -53,6 +53,7 @@ type MeshManagerImpl struct { ipAllocator ip.IPAllocator interfaceManipulator wg.WgInterfaceManipulator Monitor MeshMonitor + cmdRunner cmd.CmdRunner OnDelete func(MeshProvider) } @@ -108,21 +109,38 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor { return m.Monitor } -// Prune implements MeshManager. -func (m *MeshManagerImpl) Prune() error { - for _, mesh := range m.Meshes { - err := mesh.Prune() +// CreateMeshParams contains the parameters required to create a mesh +type CreateMeshParams struct { + Port int + Conf *conf.WgConfiguration +} + +// getConf: gets the new configuration with the base configuration overriden +// from the recent +func (m *MeshManagerImpl) getConf(override *conf.WgConfiguration) (*conf.WgConfiguration, error) { + meshConfiguration := m.conf.BaseConfiguration + + if override != nil { + newConf, err := conf.MergeMeshConfiguration(meshConfiguration, *override) if err != nil { - return err + return nil, err } + + meshConfiguration = newConf } - return nil + return &meshConfiguration, nil } // CreateMesh: Creates a new mesh, stores it and returns the mesh id -func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { +func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) { + meshConfiguration, err := m.getConf(args.Conf) + + if err != nil { + return "", err + } + meshId, err := m.idGenerator.GetId() var ifName string = "" @@ -131,8 +149,10 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { return "", err } + m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PreUp...) + if !m.conf.StubWg { - ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey) + ifName, err = m.interfaceManipulator.CreateInterface(args.Port, m.HostParameters.PrivateKey) if err != nil { return "", fmt.Errorf("error creating mesh: %w", err) @@ -140,12 +160,13 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { } nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ - DevName: ifName, - Port: port, - Conf: m.conf, - Client: m.Client, - MeshId: meshId, - NodeID: m.HostParameters.GetPublicKey(), + DevName: ifName, + Port: args.Port, + Conf: meshConfiguration, + Client: m.Client, + MeshId: meshId, + DaemonConf: m.conf, + NodeID: m.HostParameters.GetPublicKey(), }) if err != nil { @@ -155,6 +176,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { m.lock.Lock() m.Meshes[meshId] = nodeManager m.lock.Unlock() + + m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...) + return meshId, nil } @@ -162,6 +186,7 @@ type AddMeshParams struct { MeshId string WgPort int MeshBytes []byte + Conf *conf.WgConfiguration } // AddMesh: Add the mesh to the list of meshes @@ -169,6 +194,14 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { var ifName string var err error + meshConfiguration, err := m.getConf(params.Conf) + + if err != nil { + return err + } + + m.cmdRunner.RunCommands(meshConfiguration.PreUp...) + if !m.conf.StubWg { ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey) @@ -178,14 +211,17 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { } meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ - DevName: ifName, - Port: params.WgPort, - Conf: m.conf, - Client: m.Client, - MeshId: params.MeshId, - NodeID: m.HostParameters.GetPublicKey(), + DevName: ifName, + Port: params.WgPort, + Conf: meshConfiguration, + Client: m.Client, + MeshId: params.MeshId, + DaemonConf: m.conf, + NodeID: m.HostParameters.GetPublicKey(), }) + m.cmdRunner.RunCommands(meshConfiguration.PostUp...) + if err != nil { return err } @@ -255,10 +291,11 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { } node := s.nodeFactory.Build(&MeshNodeFactoryParams{ - PublicKey: &pubKey, - NodeIP: nodeIP, - WgPort: params.WgPort, - Endpoint: params.Endpoint, + PublicKey: &pubKey, + NodeIP: nodeIP, + WgPort: params.WgPort, + Endpoint: params.Endpoint, + MeshConfig: mesh.GetConfiguration(), }) if !s.conf.StubWg { @@ -301,6 +338,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { delete(s.Meshes, meshId) s.lock.Unlock() + s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...) + if !s.conf.StubWg { device, err := mesh.GetDevice() @@ -315,6 +354,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { } } + s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...) + return err } @@ -436,7 +477,7 @@ func (s *MeshManagerImpl) Close() error { // NewMeshManagerParams params required to create an instance of a mesh manager type NewMeshManagerParams struct { - Conf conf.WgMeshConfiguration + Conf conf.DaemonConfiguration Client *wgctrl.Client MeshProvider MeshProviderFactory NodeFactory MeshNodeFactory @@ -445,6 +486,7 @@ type NewMeshManagerParams struct { InterfaceManipulator wg.WgInterfaceManipulator ConfigApplyer MeshConfigApplyer RouteManager RouteManager + CommandRunner cmd.CmdRunner OnDelete func(MeshProvider) } @@ -471,6 +513,10 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager { m.RouteManager = NewRouteManager(m, ¶ms.Conf) } + if params.CommandRunner == nil { + m.cmdRunner = &cmd.UnixCmdRunner{} + } + m.idGenerator = params.IdGenerator m.ipAllocator = params.IPAllocator m.interfaceManipulator = params.InterfaceManipulator diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index b1e551c..661c9e4 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -9,16 +9,9 @@ import ( "github.com/tim-beatham/wgmesh/pkg/wg" ) -func getMeshConfiguration() *conf.WgMeshConfiguration { - return &conf.WgMeshConfiguration{ - GrpcPort: "8080", - Endpoint: "abc.com", - ClusterSize: 64, - SyncRate: 4, - BranchRate: 3, - InterClusterChance: 0.15, - InfectionCount: 2, - KeepAliveTime: 60, +func getMeshConfiguration() *conf.DaemonConfiguration { + return &conf.DaemonConfiguration{ + GrpcPort: 8080, } } diff --git a/pkg/mesh/pruner.go b/pkg/mesh/pruner.go deleted file mode 100644 index 904bf22..0000000 --- a/pkg/mesh/pruner.go +++ /dev/null @@ -1,16 +0,0 @@ -package mesh - -import ( - "github.com/tim-beatham/wgmesh/pkg/conf" - "github.com/tim-beatham/wgmesh/pkg/lib" -) - -func pruneFunction(m MeshManager) lib.TimerFunc { - return func() error { - return m.Prune() - } -} - -func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer { - return lib.NewTimer(pruneFunction(m), conf.PruneTime/2) -} diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index 70ac341..39ef2bc 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -14,7 +14,7 @@ type RouteManager interface { type RouteManagerImpl struct { meshManager MeshManager - conf *conf.WgMeshConfiguration + conf *conf.DaemonConfiguration } func (r *RouteManagerImpl) UpdateRoutes() error { @@ -38,7 +38,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return err } - if r.conf.AdvertiseDefaultRoute { + if *r.conf.BaseConfiguration.AdvertiseDefaultRoute { _, ipv6Default, _ := net.ParseCIDR("::/0") mesh1.AddRoutes(NodeID(self), @@ -106,6 +106,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error { return nil } -func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager { +func NewRouteManager(m MeshManager, conf *conf.DaemonConfiguration) RouteManager { return &RouteManagerImpl{meshManager: m, conf: conf} } diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 96811c5..ca1efb9 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -81,6 +81,11 @@ type MeshProviderStub struct { snapshot *MeshSnapshotStub } +// GetConfiguration implements MeshProvider. +func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration { + panic("unimplemented") +} + // Mark implements MeshProvider. func (*MeshProviderStub) Mark(nodeId string) { panic("unimplemented") @@ -195,7 +200,7 @@ func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams) } type StubNodeFactory struct { - Config *conf.WgMeshConfiguration + Config *conf.DaemonConfiguration } func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode { @@ -274,7 +279,7 @@ func NewMeshManagerStub() MeshManager { return &MeshManagerStub{meshes: make(map[string]MeshProvider)} } -func (m *MeshManagerStub) CreateMesh(port int) (string, error) { +func (m *MeshManagerStub) CreateMesh(*CreateMeshParams) (string, error) { return "tim123", nil } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 16b9d9c..b04f453 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -145,6 +145,9 @@ type MeshProvider interface { // Mark: marks the node as unreachable. This is not broadcast to the entire // this is not considered when syncing node state Mark(nodeId string) + // GetConfiguration: gets the configuration parameters specific for this + // mesh network + GetConfiguration() *conf.WgConfiguration } // HostParameters contains the IDs of a node @@ -159,12 +162,13 @@ func (h *HostParameters) GetPublicKey() string { // MeshProviderFactoryParams parameters required to build a mesh provider type MeshProviderFactoryParams struct { - DevName string - MeshId string - Port int - Conf *conf.WgMeshConfiguration - Client *wgctrl.Client - NodeID string + DevName string + MeshId string + Port int + Conf *conf.WgConfiguration + DaemonConf *conf.DaemonConfiguration + Client *wgctrl.Client + NodeID string } // MeshProviderFactory creates an instance of a mesh provider @@ -175,10 +179,11 @@ type MeshProviderFactory interface { // MeshNodeFactoryParams are the parameters required to construct // a mesh node type MeshNodeFactoryParams struct { - PublicKey *wgtypes.Key - NodeIP net.IP - WgPort int - Endpoint string + PublicKey *wgtypes.Key + NodeIP net.IP + WgPort int + Endpoint string + MeshConfig *conf.WgConfiguration } // MeshBuilder build the hosts mesh node for it to be added to the mesh diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index f6ccdf4..6b624d6 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -8,6 +8,7 @@ import ( "strconv" "time" + "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/lib" @@ -21,7 +22,25 @@ type IpcHandler struct { } func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { - meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort) + overrideConf := &conf.WgConfiguration{} + + if args.Role != "" { + role := conf.NodeType(args.Role) + overrideConf.Role = &role + } + + if args.Endpoint != "" { + overrideConf.Endpoint = &args.Endpoint + } + + if *overrideConf.Role == conf.CLIENT_ROLE { + return fmt.Errorf("cannot create a mesh with no public endpoint") + } + + meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{ + Port: args.WgPort, + Conf: overrideConf, + }) if err != nil { return err @@ -45,7 +64,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes())) i := 0 - for meshId, _ := range n.Server.GetMeshManager().GetMeshes() { + for meshId := range n.Server.GetMeshManager().GetMeshes() { meshNames[i] = meshId i++ } @@ -55,6 +74,17 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { } func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { + overrideConf := &conf.WgConfiguration{} + + if args.Role != "" { + role := conf.NodeType(args.Role) + overrideConf.Role = &role + } + + if args.Endpoint != "" { + overrideConf.Endpoint = &args.Endpoint + } + peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) if err != nil { @@ -88,6 +118,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { MeshId: args.MeshId, WgPort: args.Port, MeshBytes: meshReply.Mesh, + Conf: overrideConf, }) if err != nil { diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index d6e3db8..a74685d 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -24,7 +24,7 @@ type SyncerImpl struct { infectionCount int syncCount int cluster conn.ConnCluster - conf *conf.WgMeshConfiguration + conf *conf.DaemonConfiguration lastSync uint64 } @@ -33,7 +33,9 @@ func (s *SyncerImpl) Sync(meshId string) error { // Self can be nil if the node is removed self, _ := s.manager.GetSelf(meshId) - s.manager.GetMesh(meshId).Prune() + correspondingMesh := s.manager.GetMesh(meshId) + + correspondingMesh.Prune() if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { logging.Log.WriteInfof("No changes for %s", meshId) @@ -47,7 +49,7 @@ func (s *SyncerImpl) Sync(meshId string) error { logging.Log.WriteInfof(publicKey.String()) - nodeNames := s.manager.GetMesh(meshId).GetPeers() + nodeNames := correspondingMesh.GetPeers() if self != nil { nodeNames = lib.Filter(nodeNames, func(s string) bool { @@ -133,7 +135,7 @@ func (s *SyncerImpl) SyncMeshes() error { return nil } -func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer { +func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer { cluster, _ := conn.NewConnCluster(conf.ClusterSize) return &SyncerImpl{ manager: m, diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index 13ced7e..0e91f3b 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -91,7 +91,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro c := rpc.NewSyncServiceClient(client) - syncTimeOut := s.server.Conf.SyncRate * float64(time.Second) + syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) defer cancel() diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 4be4e30..35c1c19 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -14,5 +14,5 @@ func syncFunction(syncer Syncer) lib.TimerFunc { } func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer { - return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate)) + return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate) } diff --git a/pkg/timers/timers.go b/pkg/timers/timers.go index bbc8430..e26a644 100644 --- a/pkg/timers/timers.go +++ b/pkg/timers/timers.go @@ -11,6 +11,5 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { logging.Log.WriteInfof("Updated Timestamp") return ctrlServer.MeshManager.UpdateTimeStamp() } - return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) }