Merge pull request #54 from tim-beatham/53-run-commands-pre-up-and-post-down

53-run-commands-pre-up-and-post-down
This commit is contained in:
Tim Beatham 2023-12-10 19:22:59 +00:00 committed by GitHub
commit 27ec23f133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 466 additions and 474 deletions

View File

@ -4,13 +4,9 @@ import (
"fmt" "fmt"
ipcRpc "net/rpc" ipcRpc "net/rpc"
"os" "os"
"strings"
"time"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
@ -20,6 +16,7 @@ type CreateMeshParams struct {
Client *ipcRpc.Client Client *ipcRpc.Client
WgPort int WgPort int
Endpoint string Endpoint string
Role string
} }
func createMesh(args *CreateMeshParams) string { func createMesh(args *CreateMeshParams) string {
@ -27,6 +24,7 @@ func createMesh(args *CreateMeshParams) string {
newMeshParams := ipc.NewMeshArgs{ newMeshParams := ipc.NewMeshArgs{
WgPort: args.WgPort, WgPort: args.WgPort,
Endpoint: args.Endpoint, Endpoint: args.Endpoint,
Role: args.Role,
} }
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
@ -60,6 +58,7 @@ type JoinMeshParams struct {
IfName string IfName string
WgPort int WgPort int
Endpoint string Endpoint string
Role string
} }
func joinMesh(params *JoinMeshParams) string { func joinMesh(params *JoinMeshParams) string {
@ -69,6 +68,7 @@ func joinMesh(params *JoinMeshParams) string {
MeshId: params.MeshId, MeshId: params.MeshId,
IpAdress: params.IpAddress, IpAdress: params.IpAddress,
Port: params.WgPort, Port: params.WgPort,
Role: params.Role,
} }
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply) err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
@ -80,34 +80,6 @@ func joinMesh(params *JoinMeshParams) string {
return reply return reply
} }
func getMesh(client *ipcRpc.Client, meshId string) {
reply := new(ipc.GetMeshReply)
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
for _, node := range reply.Nodes {
fmt.Println("Public Key: " + node.PublicKey)
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")
}
}
func leaveMesh(client *ipcRpc.Client, meshId string) { func leaveMesh(client *ipcRpc.Client, meshId string) {
var reply string var reply string
@ -255,11 +227,13 @@ func main() {
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -300,6 +274,7 @@ func main() {
Client: client, Client: client,
WgPort: *newMeshPort, WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
})) }))
} }
@ -314,6 +289,7 @@ func main() {
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint, Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
})) }))
} }

View File

@ -19,13 +19,13 @@ import (
func main() { func main() {
if len(os.Args) != 2 { if len(os.Args) != 2 {
logging.Log.WriteErrorf("Need to provide configuration.yaml") logging.Log.WriteErrorf("Did not provide configuration")
return return
} }
conf, err := conf.ParseConfiguration(os.Args[1]) conf, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil { if err != nil {
logging.Log.WriteInfof("Could not parse configuration") logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return return
} }

View File

@ -24,7 +24,7 @@ type CrdtMeshManager struct {
Client *wgctrl.Client Client *wgctrl.Client
doc *automerge.Doc doc *automerge.Doc
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration conf *conf.WgConfiguration
cache *MeshCrdt cache *MeshCrdt
lastCacheHash automerge.ChangeHash lastCacheHash automerge.ChangeHash
} }
@ -74,8 +74,8 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
return false return false
} }
keepAliveTime := timestamp.Int64() return true
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime) // return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
} }
func (c *CrdtMeshManager) GetPeers() []string { func (c *CrdtMeshManager) GetPeers() []string {
@ -135,7 +135,7 @@ type NewCrdtNodeMangerParams struct {
MeshId string MeshId string
DevName string DevName string
Port int Port int
Conf conf.WgMeshConfiguration Conf *conf.WgConfiguration
Client *wgctrl.Client Client *wgctrl.Client
} }
@ -146,7 +146,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.doc = automerge.New() manager.doc = automerge.New()
manager.IfName = params.DevName manager.IfName = params.DevName
manager.Client = params.Client manager.Client = params.Client
manager.conf = &params.Conf manager.conf = params.Conf
manager.cache = nil manager.cache = nil
return &manager, nil return &manager, nil
} }
@ -473,59 +473,20 @@ func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) erro
return err return err
} }
// GetConfiguration: gets the configuration for this mesh network
func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
}
// Mark: mark the node as locally dead
func (m *CrdtMeshManager) Mark(nodeId string) {
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m) return NewAutomergeSync(m)
} }
func (m *CrdtMeshManager) Prune() error { func (m *CrdtMeshManager) Prune() error {
// nodes, err := m.doc.Path("nodes").Get()
// if err != nil {
// return err
// }
// if nodes.Kind() != automerge.KindMap {
// return errors.New("node must be a map")
// }
// values, err := nodes.Map().Values()
// if err != nil {
// return err
// }
// deletionNodes := make([]string, 0)
// for nodeId, node := range values {
// if node.Kind() != automerge.KindMap {
// return errors.New("node must be a map")
// }
// nodeMap := node.Map()
// timeStamp, err := nodeMap.Get("timestamp")
// if err != nil {
// return err
// }
// if timeStamp.Kind() != automerge.KindInt64 {
// return errors.New("timestamp is not int64")
// }
// timeValue := timeStamp.Int64()
// nowValue := time.Now().Unix()
// if nowValue-timeValue >= int64(pruneTime) {
// deletionNodes = append(deletionNodes, nodeId)
// }
// }
// for _, node := range deletionNodes {
// logging.Log.WriteInfof("Pruning %s", node)
// nodes.Map().Delete(node)
// }
return nil return nil
} }

View File

@ -22,7 +22,7 @@ func setUpTests() *TestParams {
DevName: "wg0", DevName: "wg0",
Port: 5000, Port: 5000,
Client: nil, Client: nil,
Conf: conf.WgMeshConfiguration{}, Conf: conf.DaemonConfiguration{},
}) })
return &TestParams{ return &TestParams{

View File

@ -14,13 +14,13 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{ return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId, MeshId: params.MeshId,
DevName: params.DevName, DevName: params.DevName,
Conf: *params.Conf, Conf: params.Conf,
Client: params.Client, Client: params.Client,
}) })
} }
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.WgMeshConfiguration Config conf.DaemonConfiguration
} }
// Build builds the mesh node that represents the host machine to add // Build builds the mesh node that represents the host machine to add
@ -30,7 +30,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE { if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-" grpcEndpoint = "-"
} }
@ -44,7 +44,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
Routes: make(map[string]Route), Routes: make(map[string]Route),
Description: "", Description: "",
Alias: "", Alias: "",
Type: string(f.Config.Role), Type: string(*params.MeshConfig.Role),
} }
} }
@ -54,12 +54,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
if params.Endpoint != "" { if params.Endpoint != "" {
hostName = params.Endpoint hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 { } else if len(*params.MeshConfig.Endpoint) != 0 {
hostName = f.Config.Endpoint hostName = *params.MeshConfig.Endpoint
} else { } else {
ipFunc := lib.GetPublicIP ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY { if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP ipFunc = lib.GetOutboundIP
} }

33
pkg/cmd/cmd.go Normal file
View File

@ -0,0 +1,33 @@
// cmd is a package for running commands in the different operating systems implementations
package cmd
import (
"os/exec"
"strings"
)
type CmdRunner interface {
RunCommands(commands ...string) error
}
type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields
// and then runs the command accordingly
func RunCommand(cmd string) error {
args := strings.Fields(cmd)
c := exec.Command(args[0], args[1:]...)
return c.Run()
}
func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands {
err := RunCommand(cmd)
if err != nil {
return err
}
}
return nil
}

View File

@ -4,7 +4,7 @@ package conf
import ( import (
"os" "os"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -30,172 +30,187 @@ const (
DNS_IP_DISCOVERY = "dns" DNS_IP_DISCOVERY = "dns"
) )
type WgMeshConfiguration struct { // WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
// tell if the attribute is set
type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoutes"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
// for all nodes to route their packets to
AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults"`
// Endpoint contains what value should be set as the public endpoint of this node
Endpoint *string `yaml:"publicEndpoint"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"`
// PostUp are WireGuard commands to run after adding the WG interface
PostUp []string `yaml:"postUp"`
// PreDown are WireGuard commands to run prior to removing the WG interface
PreDown []string `yaml:"preDown"`
// PostDown are WireGuard command to run after removing the WG interface
PostDown []string `yaml:"postDown"`
}
type DaemonConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath" validate:"required,file"`
// PrivateKeypath is the path to the clients private key in mTLS // PrivateKeypath is the path to the clients private key in mTLS
PrivateKeyPath string `yaml:"privateKeyPath"` PrivateKeyPath string `yaml:"privateKeyPath" validate:"required,file"`
// CaCeritifcatePath path to the certificate of the trust certificate authority // CaCeritifcatePath path to the certificate of the trust certificate authority
CaCertificatePath string `yaml:"caCertificatePath"` CaCertificatePath string `yaml:"caCertificatePath" validate:"required,file"`
// SkipCertVerification specify to skip certificate verification. Should only be used // SkipCertVerification specify to skip certificate verification. Should only be used
// in test environments // in test environments
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on // Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"` GrpcPort int `yaml:"gRPCPort" validate:"required"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or // Timeout number of seconds without response that a node is considered unreachable by gRPC
// use public IP discovery library Timeout int `yaml:"timeout" validate:"required,gte=1"`
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// AdvertiseDefaultRoute advertises a default route out of the mesh.
AdvertiseDefaultRoute bool `yaml:"advertiseDefaults"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
// ClusterSize size of the cluster to split on
ClusterSize int `yaml:"clusterSize"`
// SyncRate number of times per second to perform a sync
SyncRate float64 `yaml:"syncRate"`
// InterClusterChance proability of inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance"`
// BranchRate number of nodes to randomly communicate with
BranchRate int `yaml:"branchRate"`
// InfectionCount number of times we sync before we can no longer catch the udpate
InfectionCount int `yaml:"infectionCount"`
// KeepAliveTime number of seconds before we update node indicating that we are still alive
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code // Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"` Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible. // SyncRate specifies how long the minimum time should be between synchronisation
// If the user is globaly accessible they specify themselves as a client. SyncRate int `yaml:"syncRate" validate:"required,gte=1"`
Role NodeType `yaml:"role"` // KeepAliveTime: number of seconds before the leader of the mesh sends an update to
// KeepAliveWg configures the implementation so that we send keep alive packets to peers. // send to every member in the mesh
// KeepAlive can only be set if role is type client KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"`
KeepAliveWg int `yaml:"keepAliveWg"` // ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" valdiate:"required,gt=0"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance" valdiate:"required,gt=0"`
// BranchRate specifies the number of nodes to synchronise with when a node has
// new changes to send to the mesh
BranchRate int `yaml:"branchRate" validate:"required,gte=1"`
// InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { // ValdiateMeshConfiguration: validates the mesh configuration
if len(c.CertificatePath) == 0 { func ValidateMeshConfiguration(conf *WgConfiguration) error {
return &WgMeshConfigurationError{ validate := validator.New(validator.WithRequiredStructEnabled())
msg: "A public certificate must be specified for mTLS", err := validate.Struct(conf)
}
if conf.PostDown == nil {
conf.PostDown = make([]string, 0)
} }
if len(c.PrivateKeyPath) == 0 { if conf.PostUp == nil {
return &WgMeshConfigurationError{ conf.PostUp = make([]string, 0)
msg: "A private key must be specified for mTLS",
}
} }
if len(c.CaCertificatePath) == 0 { if conf.PreDown == nil {
return &WgMeshConfigurationError{ conf.PreDown = make([]string, 0)
msg: "A ca certificate must be specified for mTLS",
}
} }
if len(c.GrpcPort) == 0 { if conf.PreUp == nil {
return &WgMeshConfigurationError{ conf.PreUp = make([]string, 0)
msg: "A grpc port must be specified",
}
} }
if c.ClusterSize <= 0 { return err
return &WgMeshConfigurationError{
msg: "A cluster size must not be 0",
}
}
if c.SyncRate <= 0 {
return &WgMeshConfigurationError{
msg: "SyncRate cannot be negative",
}
}
if c.BranchRate <= 0 {
return &WgMeshConfigurationError{
msg: "Branch rate cannot be negative",
}
}
if c.InfectionCount <= 0 {
return &WgMeshConfigurationError{
msg: "Infection count cannot be less than 1",
}
}
if c.KeepAliveTime <= 0 {
return &WgMeshConfigurationError{
msg: "KeepAliveRate cannot be less than negative",
}
}
if c.InterClusterChance <= 0 {
return &WgMeshConfigurationError{
msg: "Intercluster chance cannot be less than 0",
}
}
if c.Timeout < 1 {
return &WgMeshConfigurationError{
msg: "Timeout should be greater than or equal to 1",
}
}
if c.PruneTime < 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
}
}
if c.KeepAliveTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be less than keep alive time",
}
}
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil
} }
// ParseConfiguration parses the mesh configuration // ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
var conf WgMeshConfiguration validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(c)
return err
}
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath) yamlBytes, err := os.ReadFile(filePath)
if err != nil { if err != nil {
logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
return nil, err return nil, err
} }
err = yaml.Unmarshal(yamlBytes, &conf) err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil { if err != nil {
logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
return nil, err return nil, err
} }
return &conf, ValidateConfiguration(&conf) return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateDaemonConfiguration(&conf)
}
// MergemeshConfiguration: merges the configuration in precedence where the last
// element in the list takes the most and the first takes the least
func MergeMeshConfiguration(cfgs ...WgConfiguration) (WgConfiguration, error) {
var result WgConfiguration
for _, cfg := range cfgs {
if cfg.AdvertiseDefaultRoute != nil {
result.AdvertiseDefaultRoute = cfg.AdvertiseDefaultRoute
}
if cfg.AdvertiseRoutes != nil {
result.AdvertiseRoutes = cfg.AdvertiseRoutes
}
if cfg.Endpoint != nil {
result.Endpoint = cfg.Endpoint
}
if cfg.IPDiscovery != nil {
result.IPDiscovery = cfg.IPDiscovery
}
if cfg.KeepAliveWg != nil {
result.KeepAliveWg = cfg.KeepAliveWg
}
if cfg.PostDown != nil {
result.PostDown = cfg.PostDown
}
if cfg.PostUp != nil {
result.PostUp = cfg.PostUp
}
if cfg.PreDown != nil {
result.PreDown = cfg.PreDown
}
if cfg.PreUp != nil {
result.PreUp = cfg.PreUp
}
if cfg.Role != nil {
result.Role = cfg.Role
}
}
return result, ValidateMeshConfiguration(&result)
} }

View File

@ -2,23 +2,12 @@ package conf
import "testing" import "testing"
func getExampleConfiguration() *WgMeshConfiguration { func getExampleConfiguration() *DaemonConfiguration {
return &WgMeshConfiguration{ return &DaemonConfiguration{
CertificatePath: "./cert/cert.pem", CertificatePath: "./cert/cert.pem",
PrivateKeyPath: "./cert/key.pem", PrivateKeyPath: "./cert/key.pem",
CaCertificatePath: "./cert/ca.pems", CaCertificatePath: "./cert/ca.pems",
SkipCertVerification: true, SkipCertVerification: true,
GrpcPort: "8080",
AdvertiseRoutes: true,
Endpoint: "localhost",
ClusterSize: 1,
SyncRate: 1,
InterClusterChance: 0.1,
BranchRate: 2,
KeepAliveTime: 4,
InfectionCount: 1,
Timeout: 2,
PruneTime: 20,
} }
} }
@ -26,7 +15,7 @@ func TestConfigurationCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.CertificatePath = "" conf.CertificatePath = ""
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -37,7 +26,7 @@ func TestConfigurationPrivateKeyPathEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.PrivateKeyPath = "" conf.PrivateKeyPath = ""
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -48,7 +37,7 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.CaCertificatePath = "" conf.CaCertificatePath = ""
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
@ -57,109 +46,21 @@ func TestConfigurationCaCertificatePathEmpty(t *testing.T) {
func TestConfigurationGrpcPortEmpty(t *testing.T) { func TestConfigurationGrpcPortEmpty(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.GrpcPort = "" conf.GrpcPort = 0
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
t.Fatal(`error should be thrown`) t.Fatal(`error should be thrown`)
} }
} }
func TestClusterSizeZero(t *testing.T) { func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration()
conf.ClusterSize = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func SyncRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func BranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.BranchRate = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func InfectionCountZero(t *testing.T) {
conf := getExampleConfiguration()
conf.InfectionCount = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func KeepAliveRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.KeepAliveTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidCOnfiguration(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
err := ValidateConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestTimeout(t *testing.T) {
conf := getExampleConfiguration()
conf.Timeout = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPruneTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 1
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}

View File

@ -2,6 +2,7 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
@ -21,13 +22,13 @@ type ConnectionServer struct {
ctrlProvider rpc.MeshCtrlServerServer ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes // the sync service to synchronise nodes
syncProvider rpc.SyncServiceServer syncProvider rpc.SyncServiceServer
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
listener net.Listener listener net.Listener
} }
// NewConnectionServerParams contains params for creating a new connection server // NewConnectionServerParams contains params for creating a new connection server
type NewConnectionServerParams struct { type NewConnectionServerParams struct {
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
} }
@ -76,10 +77,10 @@ func (s *ConnectionServer) Listen() error {
rpc.RegisterSyncServiceServer(s.server, s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))
s.listener = lis s.listener = lis
logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort) logging.Log.WriteInfof("GRPC listening on %d\n", s.Conf.GrpcPort)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())

View File

@ -154,12 +154,13 @@ func (m *MeshSnapshot) GetNodes() map[string]mesh.MeshNode {
} }
type TwoPhaseStoreMeshManager struct { type TwoPhaseStoreMeshManager struct {
MeshId string MeshId string
IfName string IfName string
Client *wgctrl.Client Client *wgctrl.Client
LastClock uint64 LastClock uint64
conf *conf.WgMeshConfiguration conf *conf.WgConfiguration
store *TwoPhaseMap[string, MeshNode] daemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode]
} }
// AddNode() adds a node to the mesh // AddNode() adds a node to the mesh
@ -264,7 +265,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
peerToUpdate := peers[0] peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.conf.KeepAliveTime) { if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) {
m.store.Mark(peerToUpdate) m.store.Mark(peerToUpdate)
if len(peers) < 2 { if len(peers) < 2 {
@ -506,3 +507,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
m.store.Remove(nodeId) m.store.Remove(nodeId)
return nil return nil
} }
// GetConfiguration implements mesh.MeshProvider.
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf
}

View File

@ -9,44 +9,49 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
type TwoPhaseMapFactory struct{} type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration
}
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{ return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId, MeshId: params.MeshId,
IfName: params.DevName, IfName: params.DevName,
Client: params.Client, Client: params.Client,
conf: params.Conf, conf: params.Conf,
daemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a() h := fnv.New64a()
h.Write([]byte(s)) h.Write([]byte(s))
return h.Sum64() return h.Sum64()
}, uint64(3*params.Conf.KeepAliveTime)), }, uint64(3*f.Config.KeepAliveTime)),
}, nil }, nil
} }
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.WgMeshConfiguration Config conf.DaemonConfiguration
} }
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
wgEndpoint := fmt.Sprintf("%s:%d", hostName, params.WgPort)
if f.Config.Role == conf.CLIENT_ROLE { if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-" grpcEndpoint = "-"
wgEndpoint = "-"
} }
return &MeshNode{ return &MeshNode{
HostEndpoint: grpcEndpoint, HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(), PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), WgEndpoint: wgEndpoint,
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
Routes: make(map[string]Route), Routes: make(map[string]Route),
Description: "", Description: "",
Alias: "", Alias: "",
Type: string(f.Config.Role), Type: string(*params.MeshConfig.Role),
} }
} }
@ -56,12 +61,12 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
if params.Endpoint != "" { if params.Endpoint != "" {
hostName = params.Endpoint hostName = params.Endpoint
} else if len(f.Config.Endpoint) != 0 { } else if params.MeshConfig.Endpoint != nil && len(*params.MeshConfig.Endpoint) != 0 {
hostName = f.Config.Endpoint hostName = *params.MeshConfig.Endpoint
} else { } else {
ipFunc := lib.GetPublicIP ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY { if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP ipFunc = lib.GetOutboundIP
} }

View File

@ -16,7 +16,7 @@ import (
// NewCtrlServerParams are the params requried to create a new ctrl server // NewCtrlServerParams are the params requried to create a new ctrl server
type NewCtrlServerParams struct { type NewCtrlServerParams struct {
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
@ -28,7 +28,9 @@ type NewCtrlServerParams struct {
// operation failed // operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
meshFactory := &crdt.TwoPhaseMapFactory{} meshFactory := &crdt.TwoPhaseMapFactory{
Config: params.Conf,
}
nodeFactory := &crdt.MeshNodeFactory{ nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
@ -36,7 +38,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf) configApplyer := mesh.NewWgMeshConfigApplyer()
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,
@ -87,7 +89,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return ctrlServer, nil return ctrlServer, nil
} }
func (s *MeshCtrlServer) GetConfiguration() *conf.WgMeshConfiguration { func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf return s.Conf
} }

View File

@ -34,7 +34,7 @@ type Mesh struct {
} }
type CtrlServer interface { type CtrlServer interface {
GetConfiguration() *conf.WgMeshConfiguration GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetQuerier() query.Querier GetQuerier() query.Querier
GetMeshManager() mesh.MeshManager GetMeshManager() mesh.MeshManager
@ -48,6 +48,6 @@ type MeshCtrlServer struct {
MeshManager mesh.MeshManager MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration Conf *conf.DaemonConfiguration
Querier query.Querier Querier query.Querier
} }

View File

@ -23,10 +23,10 @@ func NewCtrlServerStub() *CtrlServerStub {
} }
} }
func (c *CtrlServerStub) GetConfiguration() *conf.WgMeshConfiguration { func (c *CtrlServerStub) GetConfiguration() *conf.DaemonConfiguration {
return &conf.WgMeshConfiguration{ return &conf.DaemonConfiguration{
GrpcPort: "8080", GrpcPort: 8080,
Endpoint: "abc.com", BaseConfiguration: conf.WgConfiguration{},
} }
} }

View File

@ -16,6 +16,7 @@ type NewMeshArgs struct {
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
// or DNS entry // or DNS entry
Endpoint string Endpoint string
Role string
} }
type JoinMeshArgs struct { type JoinMeshArgs struct {
@ -25,12 +26,12 @@ type JoinMeshArgs struct {
IpAdress string IpAdress string
// Port is the WireGuard port to expose // Port is the WireGuard port to expose
Port int Port int
// Endpoint is the routable address of this machine. If not provided // Endpoint to use to override the default
// defaults to the default address
Endpoint string Endpoint string
// Client specifies whether we should join as a client of the peer // Client specifies whether we should join as a client of the peer
// we are connecting to // we are connecting to
Client bool Client bool
Role string
} }
type PutServiceArgs struct { type PutServiceArgs struct {

View File

@ -24,7 +24,6 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int hashFunc func(MeshNode) int
} }
@ -34,28 +33,33 @@ type routeNode struct {
route Route route Route
} }
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode, type convertMeshNodeParams struct {
device *wgtypes.Device, node MeshNode
peerToClients map[string][]net.IPNet, self MeshNode
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) { mesh MeshProvider
device *wgtypes.Device
peerToClients map[string][]net.IPNet
routes map[string][]routeNode
}
pubKey, err := node.GetPublicKey() func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) {
pubKey, err := params.node.GetPublicKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
allowedips := make([]net.IPNet, 1) allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost() allowedips[0] = *params.node.GetWgHost()
clients, ok := peerToClients[pubKey.String()] clients, ok := params.peerToClients[pubKey.String()]
if ok { if ok {
allowedips = append(allowedips, clients...) allowedips = append(allowedips, clients...)
} }
for _, route := range node.GetRoutes() { for _, route := range params.node.GetRoutes() {
bestRoutes := routes[route.GetDestination().String()] bestRoutes := params.routes[route.GetDestination().String()]
var pickedRoute routeNode var pickedRoute routeNode
if len(bestRoutes) == 1 { if len(bestRoutes) == 1 {
@ -66,7 +70,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
} }
// Else there is more than one candidate so consistently hash // Else there is more than one candidate so consistently hash
pickedRoute = lib.ConsistentHash(bestRoutes, self, bucketFunc, m.hashFunc) pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc)
} }
if pickedRoute.gateway == pubKey.String() { if pickedRoute.gateway == pubKey.String() {
@ -74,14 +78,20 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
} }
} }
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second config := params.mesh.GetConfiguration()
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool { var keepAlive time.Duration = time.Duration(0)
pubKey, _ := node.GetPublicKey()
if config.KeepAliveWg != nil {
keepAlive = time.Duration(*config.KeepAliveWg) * time.Second
}
existing := slices.IndexFunc(params.device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := params.node.GetPublicKey()
return p.PublicKey.String() == pubKey.String() return p.PublicKey.String() == pubKey.String()
}) })
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,7 +99,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, self MeshNode,
// Don't override the existing IP in case it already exists // Don't override the existing IP in case it already exists
if existing != -1 { if existing != -1 {
endpoint = device.Peers[existing].Endpoint endpoint = params.device.Peers[existing].Endpoint
} }
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
@ -127,7 +137,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
v6Default, _, _ := net.ParseCIDR("::/0") v6Default, _, _ := net.ParseCIDR("::/0")
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0") v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && m.config.AdvertiseDefaultRoute { if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
return true return true
} }
@ -230,7 +240,10 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
peer := m.getCorrespondingPeer(params.peers, self) peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey() pubKey, _ := peer.GetPublicKey()
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
config := params.mesh.GetConfiguration()
keepAlive := time.Duration(*config.KeepAliveWg) * time.Second
endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint()) endpoint, err := net.ResolveUDPAddr("udp", peer.GetWgEndpoint())
if err != nil { if err != nil {
@ -308,7 +321,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost()) peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) { if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes) cfg, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
device: params.dev,
peerToClients: peerToClients,
routes: params.routes,
})
if err != nil { if err != nil {
return nil, err return nil, err
@ -325,7 +345,14 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
continue continue
} }
peer, err := m.convertMeshNode(n, self, params.dev, peerToClients, params.routes) peer, err := m.convertMeshNode(convertMeshNodeParams{
node: n,
self: self,
mesh: params.mesh,
peerToClients: peerToClients,
routes: params.routes,
device: params.dev,
})
if err != nil { if err != nil {
return nil, err return nil, err
@ -468,9 +495,8 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }
func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer { func NewWgMeshConfigApplyer() MeshConfigApplyer {
return &WgMeshConfigApplyer{ return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(), routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int { hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey() pubKey, _ := mn.GetPublicKey()

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/tim-beatham/wgmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -14,7 +15,7 @@ import (
) )
type MeshManager interface { type MeshManager interface {
CreateMesh(port int) (string, error) CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
@ -30,7 +31,6 @@ type MeshManager interface {
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
Prune() error
Close() error Close() error
GetMonitor() MeshMonitor GetMonitor() MeshMonitor
GetNode(string, string) MeshNode GetNode(string, string) MeshNode
@ -45,7 +45,7 @@ type MeshManagerImpl struct {
// HostParameters contains information that uniquely locates // HostParameters contains information that uniquely locates
// the node in the mesh network. // the node in the mesh network.
HostParameters *HostParameters HostParameters *HostParameters
conf *conf.WgMeshConfiguration conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory meshProviderFactory MeshProviderFactory
nodeFactory MeshNodeFactory nodeFactory MeshNodeFactory
configApplyer MeshConfigApplyer configApplyer MeshConfigApplyer
@ -53,6 +53,7 @@ type MeshManagerImpl struct {
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor Monitor MeshMonitor
cmdRunner cmd.CmdRunner
OnDelete func(MeshProvider) OnDelete func(MeshProvider)
} }
@ -108,21 +109,38 @@ func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor return m.Monitor
} }
// Prune implements MeshManager. // CreateMeshParams contains the parameters required to create a mesh
func (m *MeshManagerImpl) Prune() error { type CreateMeshParams struct {
for _, mesh := range m.Meshes { Port int
err := mesh.Prune() Conf *conf.WgConfiguration
}
// getConf: gets the new configuration with the base configuration overriden
// from the recent
func (m *MeshManagerImpl) getConf(override *conf.WgConfiguration) (*conf.WgConfiguration, error) {
meshConfiguration := m.conf.BaseConfiguration
if override != nil {
newConf, err := conf.MergeMeshConfiguration(meshConfiguration, *override)
if err != nil { if err != nil {
return err return nil, err
} }
meshConfiguration = newConf
} }
return nil return &meshConfiguration, nil
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
meshConfiguration, err := m.getConf(args.Conf)
if err != nil {
return "", err
}
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = "" var ifName string = ""
@ -131,8 +149,10 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
return "", err return "", err
} }
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PreUp...)
if !m.conf.StubWg { if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey) ifName, err = m.interfaceManipulator.CreateInterface(args.Port, m.HostParameters.PrivateKey)
if err != nil { if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
@ -140,12 +160,13 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
} }
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: ifName, DevName: ifName,
Port: port, Port: args.Port,
Conf: m.conf, Conf: meshConfiguration,
Client: m.Client, Client: m.Client,
MeshId: meshId, MeshId: meshId,
NodeID: m.HostParameters.GetPublicKey(), DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
}) })
if err != nil { if err != nil {
@ -155,6 +176,9 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
m.lock.Lock() m.lock.Lock()
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
m.lock.Unlock() m.lock.Unlock()
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
return meshId, nil return meshId, nil
} }
@ -162,6 +186,7 @@ type AddMeshParams struct {
MeshId string MeshId string
WgPort int WgPort int
MeshBytes []byte MeshBytes []byte
Conf *conf.WgConfiguration
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add the mesh to the list of meshes
@ -169,6 +194,14 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string var ifName string
var err error var err error
meshConfiguration, err := m.getConf(params.Conf)
if err != nil {
return err
}
m.cmdRunner.RunCommands(meshConfiguration.PreUp...)
if !m.conf.StubWg { if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey) ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
@ -178,14 +211,17 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
} }
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: ifName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: meshConfiguration,
Client: m.Client, Client: m.Client,
MeshId: params.MeshId, MeshId: params.MeshId,
NodeID: m.HostParameters.GetPublicKey(), DaemonConf: m.conf,
NodeID: m.HostParameters.GetPublicKey(),
}) })
m.cmdRunner.RunCommands(meshConfiguration.PostUp...)
if err != nil { if err != nil {
return err return err
} }
@ -255,10 +291,11 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
} }
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: &pubKey, PublicKey: &pubKey,
NodeIP: nodeIP, NodeIP: nodeIP,
WgPort: params.WgPort, WgPort: params.WgPort,
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
MeshConfig: mesh.GetConfiguration(),
}) })
if !s.conf.StubWg { if !s.conf.StubWg {
@ -301,6 +338,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
delete(s.Meshes, meshId) delete(s.Meshes, meshId)
s.lock.Unlock() s.lock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
if !s.conf.StubWg { if !s.conf.StubWg {
device, err := mesh.GetDevice() device, err := mesh.GetDevice()
@ -315,6 +354,8 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
} }
} }
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err return err
} }
@ -436,7 +477,7 @@ func (s *MeshManagerImpl) Close() error {
// NewMeshManagerParams params required to create an instance of a mesh manager // NewMeshManagerParams params required to create an instance of a mesh manager
type NewMeshManagerParams struct { type NewMeshManagerParams struct {
Conf conf.WgMeshConfiguration Conf conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
MeshProvider MeshProviderFactory MeshProvider MeshProviderFactory
NodeFactory MeshNodeFactory NodeFactory MeshNodeFactory
@ -445,6 +486,7 @@ type NewMeshManagerParams struct {
InterfaceManipulator wg.WgInterfaceManipulator InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer ConfigApplyer MeshConfigApplyer
RouteManager RouteManager RouteManager RouteManager
CommandRunner cmd.CmdRunner
OnDelete func(MeshProvider) OnDelete func(MeshProvider)
} }
@ -471,6 +513,10 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
m.RouteManager = NewRouteManager(m, &params.Conf) m.RouteManager = NewRouteManager(m, &params.Conf)
} }
if params.CommandRunner == nil {
m.cmdRunner = &cmd.UnixCmdRunner{}
}
m.idGenerator = params.IdGenerator m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator m.interfaceManipulator = params.InterfaceManipulator

View File

@ -9,16 +9,9 @@ import (
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/wgmesh/pkg/wg"
) )
func getMeshConfiguration() *conf.WgMeshConfiguration { func getMeshConfiguration() *conf.DaemonConfiguration {
return &conf.WgMeshConfiguration{ return &conf.DaemonConfiguration{
GrpcPort: "8080", GrpcPort: 8080,
Endpoint: "abc.com",
ClusterSize: 64,
SyncRate: 4,
BranchRate: 3,
InterClusterChance: 0.15,
InfectionCount: 2,
KeepAliveTime: 60,
} }
} }

View File

@ -1,16 +0,0 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func pruneFunction(m MeshManager) lib.TimerFunc {
return func() error {
return m.Prune()
}
}
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
}

View File

@ -14,7 +14,7 @@ type RouteManager interface {
type RouteManagerImpl struct { type RouteManagerImpl struct {
meshManager MeshManager meshManager MeshManager
conf *conf.WgMeshConfiguration conf *conf.DaemonConfiguration
} }
func (r *RouteManagerImpl) UpdateRoutes() error { func (r *RouteManagerImpl) UpdateRoutes() error {
@ -38,7 +38,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err return err
} }
if r.conf.AdvertiseDefaultRoute { if *r.conf.BaseConfiguration.AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0") _, ipv6Default, _ := net.ParseCIDR("::/0")
mesh1.AddRoutes(NodeID(self), mesh1.AddRoutes(NodeID(self),
@ -106,6 +106,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return nil return nil
} }
func NewRouteManager(m MeshManager, conf *conf.WgMeshConfiguration) RouteManager { func NewRouteManager(m MeshManager, conf *conf.DaemonConfiguration) RouteManager {
return &RouteManagerImpl{meshManager: m, conf: conf} return &RouteManagerImpl{meshManager: m, conf: conf}
} }

View File

@ -81,6 +81,11 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
// GetConfiguration implements MeshProvider.
func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration {
panic("unimplemented")
}
// Mark implements MeshProvider. // Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) { func (*MeshProviderStub) Mark(nodeId string) {
panic("unimplemented") panic("unimplemented")
@ -195,7 +200,7 @@ func (s *StubMeshProviderFactory) CreateMesh(params *MeshProviderFactoryParams)
} }
type StubNodeFactory struct { type StubNodeFactory struct {
Config *conf.WgMeshConfiguration Config *conf.DaemonConfiguration
} }
func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode { func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
@ -274,7 +279,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(port int) (string, error) { func (m *MeshManagerStub) CreateMesh(*CreateMeshParams) (string, error) {
return "tim123", nil return "tim123", nil
} }

View File

@ -145,6 +145,9 @@ type MeshProvider interface {
// Mark: marks the node as unreachable. This is not broadcast to the entire // Mark: marks the node as unreachable. This is not broadcast to the entire
// this is not considered when syncing node state // this is not considered when syncing node state
Mark(nodeId string) Mark(nodeId string)
// GetConfiguration: gets the configuration parameters specific for this
// mesh network
GetConfiguration() *conf.WgConfiguration
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node
@ -159,12 +162,13 @@ func (h *HostParameters) GetPublicKey() string {
// MeshProviderFactoryParams parameters required to build a mesh provider // MeshProviderFactoryParams parameters required to build a mesh provider
type MeshProviderFactoryParams struct { type MeshProviderFactoryParams struct {
DevName string DevName string
MeshId string MeshId string
Port int Port int
Conf *conf.WgMeshConfiguration Conf *conf.WgConfiguration
Client *wgctrl.Client DaemonConf *conf.DaemonConfiguration
NodeID string Client *wgctrl.Client
NodeID string
} }
// MeshProviderFactory creates an instance of a mesh provider // MeshProviderFactory creates an instance of a mesh provider
@ -175,10 +179,11 @@ type MeshProviderFactory interface {
// MeshNodeFactoryParams are the parameters required to construct // MeshNodeFactoryParams are the parameters required to construct
// a mesh node // a mesh node
type MeshNodeFactoryParams struct { type MeshNodeFactoryParams struct {
PublicKey *wgtypes.Key PublicKey *wgtypes.Key
NodeIP net.IP NodeIP net.IP
WgPort int WgPort int
Endpoint string Endpoint string
MeshConfig *conf.WgConfiguration
} }
// MeshBuilder build the hosts mesh node for it to be added to the mesh // MeshBuilder build the hosts mesh node for it to be added to the mesh

View File

@ -8,6 +8,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -21,7 +22,25 @@ type IpcHandler struct {
} }
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort) overrideConf := &conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
if *overrideConf.Role == conf.CLIENT_ROLE {
return fmt.Errorf("cannot create a mesh with no public endpoint")
}
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgPort,
Conf: overrideConf,
})
if err != nil { if err != nil {
return err return err
@ -45,7 +64,7 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes())) meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0 i := 0
for meshId, _ := range n.Server.GetMeshManager().GetMeshes() { for meshId := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId meshNames[i] = meshId
i++ i++
} }
@ -55,6 +74,17 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
} }
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
overrideConf := &conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
if err != nil { if err != nil {
@ -88,6 +118,7 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
MeshId: args.MeshId, MeshId: args.MeshId,
WgPort: args.Port, WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
Conf: overrideConf,
}) })
if err != nil { if err != nil {

View File

@ -24,7 +24,7 @@ type SyncerImpl struct {
infectionCount int infectionCount int
syncCount int syncCount int
cluster conn.ConnCluster cluster conn.ConnCluster
conf *conf.WgMeshConfiguration conf *conf.DaemonConfiguration
lastSync uint64 lastSync uint64
} }
@ -33,7 +33,9 @@ func (s *SyncerImpl) Sync(meshId string) error {
// Self can be nil if the node is removed // Self can be nil if the node is removed
self, _ := s.manager.GetSelf(meshId) self, _ := s.manager.GetSelf(meshId)
s.manager.GetMesh(meshId).Prune() correspondingMesh := s.manager.GetMesh(meshId)
correspondingMesh.Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("No changes for %s", meshId)
@ -47,7 +49,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof(publicKey.String()) logging.Log.WriteInfof(publicKey.String())
nodeNames := s.manager.GetMesh(meshId).GetPeers() nodeNames := correspondingMesh.GetPeers()
if self != nil { if self != nil {
nodeNames = lib.Filter(nodeNames, func(s string) bool { nodeNames = lib.Filter(nodeNames, func(s string) bool {
@ -133,7 +135,7 @@ func (s *SyncerImpl) SyncMeshes() error {
return nil return nil
} }
func NewSyncer(m mesh.MeshManager, conf *conf.WgMeshConfiguration, r SyncRequester) Syncer { func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer {
cluster, _ := conn.NewConnCluster(conf.ClusterSize) cluster, _ := conn.NewConnCluster(conf.ClusterSize)
return &SyncerImpl{ return &SyncerImpl{
manager: m, manager: m,

View File

@ -91,7 +91,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second) syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel() defer cancel()

View File

@ -14,5 +14,5 @@ func syncFunction(syncer Syncer) lib.TimerFunc {
} }
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer { func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate)) return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate)
} }

View File

@ -11,6 +11,5 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
logging.Log.WriteInfof("Updated Timestamp") logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp() return ctrlServer.MeshManager.UpdateTimeStamp()
} }
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
} }