diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 7f879f7..ab759d5 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -22,25 +22,22 @@ type CreateMeshParams struct { AdvertiseDefault bool } -func createMesh(params *CreateMeshParams) string { +func createMesh(client *ipc.ClientIpc, args *ipc.NewMeshArgs) { var reply string - newMeshParams := ipc.NewMeshArgs{ - WgArgs: params.WgArgs, - } - - err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) + err := client.CreateMesh(args, &reply) if err != nil { - return err.Error() + fmt.Println(err.Error()) + return } - return reply + fmt.Println(reply) } -func listMeshes(client *ipcRpc.Client) { +func listMeshes(client *ipc.ClientIpc) { reply := new(ipc.ListMeshReply) - err := client.Call("IpcHandler.ListMeshes", "", &reply) + err := client.ListMeshes(reply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -52,38 +49,22 @@ func listMeshes(client *ipcRpc.Client) { } } -type JoinMeshParams struct { - Client *ipcRpc.Client - MeshId string - IpAddress string - Endpoint string - WgArgs ipc.WireGuardArgs - AdvertiseRoutes bool - AdvertiseDefault bool -} - -func joinMesh(params *JoinMeshParams) string { +func joinMesh(client *ipc.ClientIpc, args ipc.JoinMeshArgs) { var reply string - args := ipc.JoinMeshArgs{ - MeshId: params.MeshId, - IpAdress: params.IpAddress, - WgArgs: params.WgArgs, - } - - err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply) + err := client.JoinMesh(args, &reply) if err != nil { - return err.Error() + fmt.Println(err.Error()) } - return reply + fmt.Println(reply) } -func leaveMesh(client *ipcRpc.Client, meshId string) { +func leaveMesh(client *ipc.ClientIpc, meshId string) { var reply string - err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply) + err := client.LeaveMesh(meshId, &reply) if err != nil { fmt.Println(err.Error()) @@ -93,10 +74,10 @@ func leaveMesh(client *ipcRpc.Client, meshId string) { fmt.Println(reply) } -func getGraph(client *ipcRpc.Client) { +func getGraph(client *ipc.ClientIpc) { listMeshesReply := new(ipc.ListMeshReply) - err := client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) + err := client.ListMeshes(listMeshesReply) if err != nil { fmt.Println(err.Error()) @@ -108,7 +89,7 @@ func getGraph(client *ipcRpc.Client) { for _, meshId := range listMeshesReply.Meshes { var meshReply ipc.GetMeshReply - err := client.Call("IpcHandler.GetMesh", &meshId, &meshReply) + err := client.GetMesh(meshId, &meshReply) if err != nil { fmt.Println(err.Error()) @@ -129,10 +110,15 @@ func getGraph(client *ipcRpc.Client) { fmt.Println(dot) } -func queryMesh(client *ipcRpc.Client, meshId, query string) { +func queryMesh(client *ipc.ClientIpc, meshId, query string) { var reply string - err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply) + args := ipc.QueryMesh{ + MeshId: meshId, + Query: query, + } + + err := client.Query(args, &reply) if err != nil { fmt.Println(err.Error()) @@ -142,11 +128,13 @@ func queryMesh(client *ipcRpc.Client, meshId, query string) { fmt.Println(reply) } -// putDescription: puts updates the description about the node to the meshes -func putDescription(client *ipcRpc.Client, description string) { +func putDescription(client *ipc.ClientIpc, meshId, description string) { var reply string - err := client.Call("IpcHandler.PutDescription", &description, &reply) + err := client.PutDescription(ipc.PutDescriptionArgs{ + MeshId: meshId, + Description: description, + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -157,10 +145,13 @@ func putDescription(client *ipcRpc.Client, description string) { } // putAlias: puts an alias for the node -func putAlias(client *ipcRpc.Client, alias string) { +func putAlias(client *ipc.ClientIpc, meshid, alias string) { var reply string - err := client.Call("IpcHandler.PutAlias", &alias, &reply) + err := client.PutAlias(ipc.PutAliasArgs{ + MeshId: meshid, + Alias: alias, + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -170,15 +161,14 @@ func putAlias(client *ipcRpc.Client, alias string) { fmt.Println(reply) } -func setService(client *ipcRpc.Client, service, value string) { +func setService(client *ipc.ClientIpc, meshId, service, value string) { var reply string - serviceArgs := &ipc.PutServiceArgs{ + err := client.PutService(ipc.PutServiceArgs{ + MeshId: meshId, Service: service, Value: value, - } - - err := client.Call("IpcHandler.PutService", serviceArgs, &reply) + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -188,10 +178,13 @@ func setService(client *ipcRpc.Client, service, value string) { fmt.Println(reply) } -func deleteService(client *ipcRpc.Client, service string) { +func deleteService(client *ipc.ClientIpc, meshId, service string) { var reply string - err := client.Call("IpcHandler.PutService", &service, &reply) + err := client.DeleteService(ipc.DeleteServiceArgs{ + MeshId: meshId, + Service: service, + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -226,7 +219,6 @@ func main() { }) var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ - Default: "peer", Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + " protocol", @@ -259,7 +251,6 @@ func main() { }) var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ - Default: "peer", Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + " protocol", @@ -302,6 +293,16 @@ func main() { Help: "Description of the node in the mesh", }) + var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + + var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{ Required: true, Help: "Alias of the node to set can be used in DNS to lookup an IP address", @@ -316,11 +317,21 @@ func main() { Help: "Value of the service to advertise in the mesh network", }) + var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{ Required: true, Help: "Key of the service to remove", }) + var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + err := parser.Parse(os.Args) if err != nil { @@ -328,16 +339,13 @@ func main() { return } - client, err := ipcRpc.DialHTTP("unix", SockAddr) + client, err := ipc.NewClientIpc() if err != nil { - fmt.Println(err.Error()) - return + panic(err) } if newMeshCmd.Happened() { - fmt.Println(createMesh(&CreateMeshParams{ - Client: client, - Endpoint: *newMeshEndpoint, + args := &ipc.NewMeshArgs{ WgArgs: ipc.WireGuardArgs{ Endpoint: *newMeshEndpoint, Role: *newMeshRole, @@ -346,7 +354,9 @@ func main() { AdvertiseDefaultRoute: *newMeshAdvertiseDefaults, AdvertiseRoutes: *newMeshAdvertiseRoutes, }, - })) + } + + createMesh(client, args) } if listMeshCmd.Happened() { @@ -354,11 +364,9 @@ func main() { } if joinMeshCmd.Happened() { - fmt.Println(joinMesh(&JoinMeshParams{ - Client: client, + args := ipc.JoinMeshArgs{ IpAddress: *joinMeshIpAddress, MeshId: *joinMeshId, - Endpoint: *joinMeshEndpoint, WgArgs: ipc.WireGuardArgs{ Endpoint: *joinMeshEndpoint, Role: *joinMeshRole, @@ -367,7 +375,8 @@ func main() { AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults, AdvertiseRoutes: *joinMeshAdvertiseRoutes, }, - })) + } + joinMesh(client, args) } if getGraphCmd.Happened() { @@ -383,18 +392,18 @@ func main() { } if putDescriptionCmd.Happened() { - putDescription(client, *description) + putDescription(client, *descriptionMeshId, *description) } if putAliasCmd.Happened() { - putAlias(client, *alias) + putAlias(client, *aliasMeshId, *alias) } if setServiceCmd.Happened() { - setService(client, *serviceKey, *serviceValue) + setService(client, *serviceMeshId, *serviceKey, *serviceValue) } if deleteServiceCmd.Happened() { - deleteService(client, *deleteServiceKey) + deleteService(client, *deleteServiceMeshid, *deleteServiceKey) } } diff --git a/cmd/wgmeshd/configuration.yaml b/cmd/wgmeshd/configuration.yaml index 01bb357..49aac29 100644 --- a/cmd/wgmeshd/configuration.yaml +++ b/cmd/wgmeshd/configuration.yaml @@ -10,5 +10,5 @@ syncRate: 1 interClusterChance: 0.15 branchRate: 3 infectionCount: 3 -keepAliveTime: 10 +heartBeatTime: 10 pruneTime: 20 diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index 95aa9b3..226b2d7 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -59,6 +59,11 @@ func main() { } ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) + + if err != nil { + panic(err) + } + syncProvider.Server = ctrlServer syncRequester = sync.NewSyncRequester(ctrlServer) syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) diff --git a/examples/meshtomesh/shared/configuration.yaml b/examples/meshtomesh/shared/configuration.yaml index 01bb357..49aac29 100644 --- a/examples/meshtomesh/shared/configuration.yaml +++ b/examples/meshtomesh/shared/configuration.yaml @@ -10,5 +10,5 @@ syncRate: 1 interClusterChance: 0.15 branchRate: 3 infectionCount: 3 -keepAliveTime: 10 +heartBeatTime: 10 pruneTime: 20 diff --git a/examples/simple/shared/configuration.yaml b/examples/simple/shared/configuration.yaml index 01bb357..49aac29 100644 --- a/examples/simple/shared/configuration.yaml +++ b/examples/simple/shared/configuration.yaml @@ -10,5 +10,5 @@ syncRate: 1 interClusterChance: 0.15 branchRate: 3 infectionCount: 3 -keepAliveTime: 10 +heartBeatTime: 10 pruneTime: 20 diff --git a/pkg/api/apiserver.go b/pkg/api/apiserver.go index d11fc56..46f665d 100644 --- a/pkg/api/apiserver.go +++ b/pkg/api/apiserver.go @@ -4,8 +4,6 @@ import ( "fmt" "net/http" - ipcRpc "net/rpc" - "github.com/gin-gonic/gin" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" @@ -13,8 +11,6 @@ import ( "github.com/tim-beatham/wgmesh/pkg/what8words" ) -const SockAddr = "/tmp/wgmesh_ipc.sock" - type ApiServer interface { GetMeshes(c *gin.Context) Run(addr string) error @@ -22,7 +18,7 @@ type ApiServer interface { type SmegServer struct { router *gin.Engine - client *ipcRpc.Client + client *ipc.ClientIpc words *what8words.What8Words } @@ -106,7 +102,7 @@ func (s *SmegServer) CreateMesh(c *gin.Context) { var reply string - err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply) + err := s.client.CreateMesh(&ipcRequest, &reply) if err != nil { c.JSON(http.StatusBadRequest, &gin.H{ @@ -132,8 +128,8 @@ func (s *SmegServer) JoinMesh(c *gin.Context) { } ipcRequest := ipc.JoinMeshArgs{ - MeshId: joinMesh.MeshId, - IpAdress: joinMesh.Bootstrap, + MeshId: joinMesh.MeshId, + IpAddress: joinMesh.Bootstrap, WgArgs: ipc.WireGuardArgs{ WgPort: joinMesh.WgPort, }, @@ -141,7 +137,7 @@ func (s *SmegServer) JoinMesh(c *gin.Context) { var reply string - err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply) + err := s.client.JoinMesh(ipcRequest, &reply) if err != nil { c.JSON(http.StatusBadRequest, &gin.H{ @@ -164,7 +160,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) { getMeshReply := new(ipc.GetMeshReply) - err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply) + err := s.client.GetMesh(meshid, getMeshReply) if err != nil { c.JSON(http.StatusNotFound, @@ -182,7 +178,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) { func (s *SmegServer) GetMeshes(c *gin.Context) { listMeshesReply := new(ipc.ListMeshReply) - err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) + err := s.client.ListMeshes(listMeshesReply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -195,7 +191,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) { for _, mesh := range listMeshesReply.Meshes { getMeshReply := new(ipc.GetMeshReply) - err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply) + err := s.client.GetMesh(mesh, getMeshReply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -215,7 +211,7 @@ func (s *SmegServer) Run(addr string) error { } func NewSmegServer(conf ApiServerConf) (ApiServer, error) { - client, err := ipcRpc.DialHTTP("unix", SockAddr) + client, err := ipc.NewClientIpc() if err != nil { return nil, err diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 86ca077..e0bf283 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -47,7 +47,6 @@ type WgConfiguration struct { // If the user is globaly accessible they specify themselves as a client. Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"` // KeepAliveWg configures the implementation so that we send keep alive packets to peers. - // KeepAlive can only be set if role is type client KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"` // PreUp are WireGuard commands to run before adding the WG interface PreUp []string `yaml:"preUp"` @@ -77,11 +76,13 @@ type DaemonConfiguration struct { Profile bool `yaml:"profile"` // StubWg whether or not to stub the WireGuard types StubWg bool `yaml:"stubWg"` - // SyncRate specifies how long the minimum time should be between synchronisation - SyncRate int `yaml:"syncRate" validate:"required,gte=1"` - // KeepAliveTime: number of seconds before the leader of the mesh sends an update to + // SyncTime specifies how long the minimum time should be between synchronisation + SyncTime int `yaml:"syncTime" validate:"required,gte=1"` + // PullTime specifies the interval between checking for configuration changes + PullTime int `yaml:"pullTime" validate:"gte=0"` + // HeartBeat: number of seconds before the leader of the mesh sends an update to // send to every member in the mesh - KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"` + HeartBeat int `yaml:"heartBeatTime" validate:"required,gte=1"` // ClusterSize specifies how many neighbours you should synchronise with per round ClusterSize int `yaml:"clusterSize" validate:"gte=1"` // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index 45c8138..4189470 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -21,11 +21,12 @@ func getExampleConfiguration() *DaemonConfiguration { Timeout: 5, Profile: false, StubWg: false, - SyncRate: 2, - KeepAliveTime: 2, + SyncTime: 2, + HeartBeat: 2, ClusterSize: 64, InterClusterChance: 0.15, BranchRate: 3, + PullTime: 0, InfectionCount: 2, BaseConfiguration: WgConfiguration{ IPDiscovery: &discovery, @@ -162,9 +163,9 @@ func TestBranchRateZero(t *testing.T) { } } -func TestSyncRateZero(t *testing.T) { +func TestsyncTimeZero(t *testing.T) { conf := getExampleConfiguration() - conf.SyncRate = 0 + conf.SyncTime = 0 err := ValidateDaemonConfiguration(conf) @@ -175,7 +176,7 @@ func TestSyncRateZero(t *testing.T) { func TestKeepAliveTimeZero(t *testing.T) { conf := getExampleConfiguration() - conf.KeepAliveTime = 0 + conf.HeartBeat = 0 err := ValidateDaemonConfiguration(conf) if err == nil { @@ -215,6 +216,17 @@ func TestInfectionCountOne(t *testing.T) { } } +func TestPullTimeNegative(t *testing.T) { + conf := getExampleConfiguration() + conf.PullTime = -1 + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + func TestValidConfiguration(t *testing.T) { conf := getExampleConfiguration() err := ValidateDaemonConfiguration(conf) diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go index 229f5ed..d5ef623 100644 --- a/pkg/crdt/datastore.go +++ b/pkg/crdt/datastore.go @@ -264,7 +264,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { peerToUpdate := peers[0] - if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.KeepAliveTime) { + if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.HeartBeat) { m.store.Mark(peerToUpdate) if len(peers) < 2 { diff --git a/pkg/crdt/datastore_test.go b/pkg/crdt/datastore_test.go index 4b8fa87..c86ec81 100644 --- a/pkg/crdt/datastore_test.go +++ b/pkg/crdt/datastore_test.go @@ -32,8 +32,8 @@ func setUpTests() *TestParams { GrpcPort: 0, Timeout: 20, Profile: false, - SyncRate: 2, - KeepAliveTime: 10, + SyncTime: 2, + HeartBeat: 10, ClusterSize: 32, InterClusterChance: 0.15, BranchRate: 3, diff --git a/pkg/crdt/factory.go b/pkg/crdt/factory.go index d2d782c..9cefe05 100644 --- a/pkg/crdt/factory.go +++ b/pkg/crdt/factory.go @@ -24,7 +24,7 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) h := fnv.New64a() h.Write([]byte(s)) return h.Sum64() - }, uint64(3*f.Config.KeepAliveTime)), + }, uint64(3*f.Config.HeartBeat)), }, nil } diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 8847ee0..2ebfb5a 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "net" - "net/rpc" "github.com/miekg/dns" "github.com/tim-beatham/wgmesh/pkg/ipc" @@ -18,7 +17,7 @@ const SockAddr = "/tmp/wgmesh_ipc.sock" const MeshRegularExpression = `(?P.+)\.(?P.+)\.smeg\.` type DNSHandler struct { - client *rpc.Client + client *ipc.ClientIpc server *dns.Server } @@ -27,7 +26,7 @@ type DNSHandler struct { func (d *DNSHandler) queryMesh(meshId, alias string) net.IP { var reply string - err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{ + err := d.client.Query(ipc.QueryMesh{ MeshId: meshId, Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias), }, &reply) @@ -97,7 +96,7 @@ func (h *DNSHandler) Close() error { } func NewDns(udpPort int) (*DNSHandler, error) { - client, err := rpc.DialHTTP("unix", SockAddr) + client, err := ipc.NewClientIpc() if err != nil { return nil, err diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index b8896ee..895af93 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -5,11 +5,27 @@ import ( "net" "net/http" "net/rpc" + ipcRpc "net/rpc" "os" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ) +const SockAddr = "/tmp/wgmesh_sock" + +type MeshIpc interface { + CreateMesh(args *NewMeshArgs, reply *string) error + ListMeshes(name string, reply *ListMeshReply) error + JoinMesh(args *JoinMeshArgs, reply *string) error + LeaveMesh(meshId string, reply *string) error + GetMesh(meshId string, reply *GetMeshReply) error + Query(query QueryMesh, reply *string) error + PutDescription(args PutDescriptionArgs, reply *string) error + PutAlias(args PutAliasArgs, reply *string) error + PutService(args PutServiceArgs, reply *string) error + DeleteService(args DeleteServiceArgs, reply *string) error +} + // WireGuardArgs are provided args specific to WireGuard type WireGuardArgs struct { // WgPort is the WireGuard port to expose @@ -39,7 +55,7 @@ type JoinMeshArgs struct { // MeshId is the ID of the mesh to join MeshId string // IpAddress is a routable IP in another mesh - IpAdress string + IpAddress string // WgArgs is the WireGuard parameters to use. WgArgs WireGuardArgs } @@ -47,6 +63,22 @@ type JoinMeshArgs struct { type PutServiceArgs struct { Service string Value string + MeshId string +} + +type DeleteServiceArgs struct { + Service string + MeshId string +} + +type PutAliasArgs struct { + Alias string + MeshId string +} + +type PutDescriptionArgs struct { + Description string + MeshId string } type GetMeshReply struct { @@ -62,20 +94,65 @@ type QueryMesh struct { Query string } -type MeshIpc interface { - CreateMesh(args *NewMeshArgs, reply *string) error - ListMeshes(name string, reply *ListMeshReply) error - JoinMesh(args JoinMeshArgs, reply *string) error - LeaveMesh(meshId string, reply *string) error - GetMesh(meshId string, reply *GetMeshReply) error - Query(query QueryMesh, reply *string) error - PutDescription(description string, reply *string) error - PutAlias(alias string, reply *string) error - PutService(args PutServiceArgs, reply *string) error - DeleteService(service string, reply *string) error +type ClientIpc struct { + client *ipcRpc.Client } -const SockAddr = "/tmp/wgmesh_ipc.sock" +func NewClientIpc() (*ClientIpc, error) { + client, err := ipcRpc.DialHTTP("unix", SockAddr) + + if err != nil { + return nil, err + } + + return &ClientIpc{ + client: client, + }, nil +} + +func (c *ClientIpc) CreateMesh(args *NewMeshArgs, reply *string) error { + return c.client.Call("IpcHandler.CreateMesh", args, reply) +} + +func (c *ClientIpc) ListMeshes(reply *ListMeshReply) error { + return c.client.Call("IpcHandler.ListMeshes", "", reply) +} + +func (c *ClientIpc) JoinMesh(args JoinMeshArgs, reply *string) error { + return c.client.Call("IpcHandler.JoinMesh", &args, reply) +} + +func (c *ClientIpc) LeaveMesh(meshId string, reply *string) error { + return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply) +} + +func (c *ClientIpc) GetMesh(meshId string, reply *GetMeshReply) error { + return c.client.Call("IpcHandler.GetMesh", &meshId, reply) +} + +func (c *ClientIpc) Query(query QueryMesh, reply *string) error { + return c.client.Call("IpcHandler.Query", &query, reply) +} + +func (c *ClientIpc) PutDescription(args PutDescriptionArgs, reply *string) error { + return c.client.Call("IpcHandler.PutDescription", &args, reply) +} + +func (c *ClientIpc) PutAlias(args PutAliasArgs, reply *string) error { + return c.client.Call("IpcHandler.PutAlias", &args, reply) +} + +func (c *ClientIpc) PutService(args PutServiceArgs, reply *string) error { + return c.client.Call("IpcHandler.PutService", &args, reply) +} + +func (c *ClientIpc) DeleteService(args DeleteServiceArgs, reply *string) error { + return c.client.Call("IpcHandler.DeleteService", &args, reply) +} + +func (c *ClientIpc) Close() error { + return c.Close() +} func RunIpcHandler(server MeshIpc) error { if err := os.RemoveAll(SockAddr); err != nil { diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 0069b03..efd60c3 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -24,10 +24,10 @@ type MeshManager interface { LeaveMesh(meshId string) error GetSelf(meshId string) (MeshNode, error) ApplyConfig() error - SetDescription(description string) error - SetAlias(alias string) error - SetService(service string, value string) error - RemoveService(service string) error + SetDescription(meshId, description string) error + SetAlias(meshId, alias string) error + SetService(meshId, service, value string) error + RemoveService(meshId, service string) error UpdateTimeStamp() error GetClient() *wgctrl.Client GetMeshes() map[string]MeshProvider @@ -61,29 +61,33 @@ func (m *MeshManagerImpl) GetRouteManager() RouteManager { } // RemoveService implements MeshManager. -func (m *MeshManagerImpl) RemoveService(service string) error { - for _, mesh := range m.Meshes { - err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service) +func (m *MeshManagerImpl) RemoveService(meshId, service string) error { + mesh := m.GetMesh(meshId) - if err != nil { - return err - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + if !mesh.NodeExists(m.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.RemoveService(m.HostParameters.GetPublicKey(), service) } // SetService implements MeshManager. -func (m *MeshManagerImpl) SetService(service string, value string) error { - for _, mesh := range m.Meshes { - err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value) +func (m *MeshManagerImpl) SetService(meshId, service, value string) error { + mesh := m.GetMesh(meshId) - if err != nil { - return err - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + if !mesh.NodeExists(m.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.AddService(m.HostParameters.GetPublicKey(), service, value) } func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode { @@ -134,6 +138,10 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) { return "", err } + if *meshConfiguration.Role == conf.CLIENT_ROLE { + return "", fmt.Errorf("cannot create mesh as a client") + } + meshId, err := m.idGenerator.GetId() var ifName string = "" @@ -348,7 +356,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { } s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...) - return err } @@ -373,43 +380,36 @@ func (s *MeshManagerImpl) ApplyConfig() error { return nil } - err := s.configApplyer.ApplyConfig() - - if err != nil { - return err - } - - return nil + return s.configApplyer.ApplyConfig() } -func (s *MeshManagerImpl) SetDescription(description string) error { - meshes := s.GetMeshes() - for _, mesh := range meshes { - if mesh.NodeExists(s.HostParameters.GetPublicKey()) { - err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description) +func (s *MeshManagerImpl) SetDescription(meshId, description string) error { + mesh := s.GetMesh(meshId) - if err != nil { - return err - } - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + if !mesh.NodeExists(s.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.SetDescription(s.HostParameters.GetPublicKey(), description) } // SetAlias implements MeshManager. -func (s *MeshManagerImpl) SetAlias(alias string) error { - meshes := s.GetMeshes() - for _, mesh := range meshes { - if mesh.NodeExists(s.HostParameters.GetPublicKey()) { - err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) +func (s *MeshManagerImpl) SetAlias(meshId, alias string) error { + mesh := s.GetMesh(meshId) - if err != nil { - return err - } - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + + if !mesh.NodeExists(s.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) } // UpdateTimeStamp updates the timestamp of this node in all meshes diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index 90b0059..d621426 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -24,8 +24,8 @@ func getMeshConfiguration() *conf.DaemonConfiguration { Timeout: 5, Profile: false, StubWg: true, - SyncRate: 2, - KeepAliveTime: 60, + SyncTime: 2, + HeartBeat: 60, ClusterSize: 64, InterClusterChance: 0.15, BranchRate: 3, @@ -213,7 +213,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) { } } -func TestSetAlias(t *testing.T) { +func TestSetAliasUpdatesAliasOfNode(t *testing.T) { manager := getMeshManager() alias := "Firpo" @@ -221,14 +221,13 @@ func TestSetAlias(t *testing.T) { Port: 5000, Conf: &conf.WgConfiguration{}, }) - manager.AddSelf(&AddSelfParams{ MeshId: meshId, WgPort: 5000, Endpoint: "abc.com:8080", }) - err := manager.SetAlias(alias) + err := manager.SetAlias(meshId, alias) if err != nil { t.Fatalf(`failed to set the alias`) @@ -245,7 +244,7 @@ func TestSetAlias(t *testing.T) { } } -func TestSetDescription(t *testing.T) { +func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) { manager := getMeshManager() description := "wooooo" @@ -254,23 +253,13 @@ func TestSetDescription(t *testing.T) { Conf: &conf.WgConfiguration{}, }) - meshId2, _ := manager.CreateMesh(&CreateMeshParams{ - Port: 5001, - Conf: &conf.WgConfiguration{}, - }) - manager.AddSelf(&AddSelfParams{ MeshId: meshId1, WgPort: 5000, Endpoint: "abc.com:8080", }) - manager.AddSelf(&AddSelfParams{ - MeshId: meshId2, - WgPort: 5000, - Endpoint: "abc.com:8080", - }) - err := manager.SetDescription(description) + err := manager.SetDescription(meshId1, description) if err != nil { t.Fatalf(`failed to set the descriptions`) @@ -285,18 +274,7 @@ func TestSetDescription(t *testing.T) { if description != self1.GetDescription() { t.Fatalf(`description should be %s was %s`, description, self1.GetDescription()) } - - self2, err := manager.GetSelf(meshId2) - - if err != nil { - t.Fatalf(`failed to set the description`) - } - - if description != self2.GetDescription() { - t.Fatalf(`description should be %s was %s`, description, self2.GetDescription()) - } } - func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { manager := getMeshManager() @@ -327,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { t.Fatalf(`failed to update the timestamp`) } } + +func TestAddServiceAddsServiceToTheMesh(t *testing.T) { + manager := getMeshManager() + + meshId1, _ := manager.CreateMesh(&CreateMeshParams{ + Port: 5000, + Conf: &conf.WgConfiguration{}, + }) + manager.AddSelf(&AddSelfParams{ + MeshId: meshId1, + WgPort: 5000, + Endpoint: "abc.com:8080", + }) + + serviceName := "hello" + manager.SetService(meshId1, serviceName, "dave") + + self, err := manager.GetSelf(meshId1) + + if err != nil { + t.Fatalf(`error thrown %s:`, err.Error()) + } + + if _, ok := self.GetServices()[serviceName]; !ok { + t.Fatalf(`service not added`) + } +} + +func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) { + manager := getMeshManager() + + meshId1, _ := manager.CreateMesh(&CreateMeshParams{ + Port: 5000, + Conf: &conf.WgConfiguration{}, + }) + manager.AddSelf(&AddSelfParams{ + MeshId: meshId1, + WgPort: 5000, + Endpoint: "abc.com:8080", + }) + + serviceName := "hello" + manager.SetService(meshId1, serviceName, "dave") + + self, err := manager.GetSelf(meshId1) + + if err != nil { + t.Fatalf(`error thrown %s:`, err.Error()) + } + + if _, ok := self.GetServices()[serviceName]; !ok { + t.Fatalf(`service not added`) + } + + manager.RemoveService(meshId1, serviceName) + self, err = manager.GetSelf(meshId1) + + if err != nil { + t.Fatalf(`error thrown %s:`, err.Error()) + } + + if _, ok := self.GetServices()[serviceName]; ok { + t.Fatalf(`service still exists`) + } +} diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 2891c6f..a49f55a 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -30,8 +30,8 @@ func (*MeshNodeStub) GetType() conf.NodeType { } // GetServices implements MeshNode. -func (*MeshNodeStub) GetServices() map[string]string { - return make(map[string]string) +func (m *MeshNodeStub) GetServices() map[string]string { + return m.services } // GetAlias implements MeshNode. @@ -249,6 +249,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode { routes: make([]Route, 0), identifier: "abc", description: "A Mesh Node Stub", + services: make(map[string]string), } } @@ -271,32 +272,32 @@ type MeshManagerStub struct { // GetRouteManager implements MeshManager. func (*MeshManagerStub) GetRouteManager() RouteManager { - panic("unimplemented") + return nil } // GetNode implements MeshManager. -func (*MeshManagerStub) GetNode(string, string) MeshNode { - panic("unimplemented") +func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode { + return nil } // RemoveService implements MeshManager. -func (*MeshManagerStub) RemoveService(service string) error { - panic("unimplemented") +func (*MeshManagerStub) RemoveService(meshId, service string) error { + return nil } // SetService implements MeshManager. -func (*MeshManagerStub) SetService(service string, value string) error { - panic("unimplemented") +func (*MeshManagerStub) SetService(meshId, service, value string) error { + return nil } // SetAlias implements MeshManager. -func (*MeshManagerStub) SetAlias(alias string) error { - panic("unimplemented") +func (*MeshManagerStub) SetAlias(meshId, alias string) error { + return nil } // Close implements MeshManager. func (*MeshManagerStub) Close() error { - panic("unimplemented") + return nil } // Prune implements MeshManager. @@ -348,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error { return nil } -func (m *MeshManagerStub) SetDescription(description string) error { +func (m *MeshManagerStub) SetDescription(meshId, description string) error { return nil } diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 81bb82b..795c649 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -43,10 +43,6 @@ func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { overrideConf := getOverrideConfiguration(&args.WgArgs) - if overrideConf.Role != nil && *overrideConf.Role == conf.CLIENT_ROLE { - return fmt.Errorf("cannot create a mesh with no public endpoint") - } - meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{ Port: args.WgArgs.WgPort, Conf: &overrideConf, @@ -83,10 +79,14 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { return nil } -func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { +func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error { overrideConf := getOverrideConfiguration(&args.WgArgs) - peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) + if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil { + return fmt.Errorf("user is already apart of the mesh") + } + + peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress) if err != nil { return err @@ -147,7 +147,6 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error { if err == nil { *reply = fmt.Sprintf("Left Mesh %s", meshId) } - return err } @@ -193,30 +192,34 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error { return nil } -func (n *IpcHandler) PutDescription(description string, reply *string) error { - err := n.Server.GetMeshManager().SetDescription(description) +func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error { + err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description) if err != nil { return err } - *reply = fmt.Sprintf("Set description to %s", description) + *reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId) return nil } -func (n *IpcHandler) PutAlias(alias string, reply *string) error { - err := n.Server.GetMeshManager().SetAlias(alias) +func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error { + if args.Alias == "" { + return fmt.Errorf("alias not provided") + } + + err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias) if err != nil { return err } - *reply = fmt.Sprintf("Set alias to %s", alias) + *reply = fmt.Sprintf("Set alias to %s", args.Alias) return nil } func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error { - err := n.Server.GetMeshManager().SetService(service.Service, service.Value) + err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value) if err != nil { return err @@ -226,8 +229,8 @@ func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error return nil } -func (n *IpcHandler) DeleteService(service string, reply *string) error { - err := n.Server.GetMeshManager().RemoveService(service) +func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error { + err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service) if err != nil { return err diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index a052bd4..f3b0c27 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -27,7 +27,7 @@ type SyncerImpl struct { syncCount int cluster conn.ConnCluster conf *conf.DaemonConfiguration - lastSync map[string]uint64 + lastSync map[string]int64 } // Sync: Sync with random nodes @@ -54,8 +54,8 @@ func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error { if self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 { logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId()) - // If not synchronised in certain pull from random neighbour - if uint64(time.Now().Unix())-s.lastSync[correspondingMesh.GetMeshId()] > 20 { + // If not synchronised in certain time pull from random neighbour + if s.conf.PullTime != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.conf.PullTime) { return s.Pull(self, correspondingMesh) } @@ -84,7 +84,9 @@ func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error { return nil } - redundancyLength := min(len(neighbours), 3) + // Peer with 2 nodes so that there is redundnacy in + // the situation the node leaves pre-emptively + redundancyLength := min(len(neighbours), 2) gossipNodes = neighbours[:redundancyLength] } else { neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) @@ -113,24 +115,23 @@ func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error { } if err != nil { - logging.Log.WriteInfof(err.Error()) + logging.Log.WriteErrorf(err.Error()) } } s.syncCount++ - logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before)) - logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) + logging.Log.WriteInfof("sync time: %v", time.Since(before)) + logging.Log.WriteInfof("number of syncs: %d", s.syncCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) if !succeeded { - // If could not gossip with anyone then repeat. s.infectionCount++ } correspondingMesh.SaveChanges() - s.lastSync[correspondingMesh.GetMeshId()] = uint64(time.Now().Unix()) + s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix() return nil } @@ -148,7 +149,7 @@ func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) error { return nil } - logging.Log.WriteInfof("PULLING from node %s", neighbour[0]) + logging.Log.WriteInfof("pulling from node %s", neighbour[0]) pullNode, err := mesh.GetNode(neighbour[0]) @@ -159,7 +160,7 @@ func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) error { err = s.requester.SyncMesh(mesh.GetMeshId(), pullNode) if err == nil || err == io.EOF { - s.lastSync[mesh.GetMeshId()] = uint64(time.Now().Unix()) + s.lastSync[mesh.GetMeshId()] = time.Now().Unix() } else { return err } @@ -206,5 +207,5 @@ func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequest infectionCount: 0, syncCount: 0, cluster: cluster, - lastSync: make(map[string]uint64)} + lastSync: make(map[string]int64)} } diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index a3fa6da..dc9e77e 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -91,7 +91,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro c := rpc.NewSyncServiceClient(client) - syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second) + syncTimeOut := float64(s.server.Conf.SyncTime) * float64(time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) defer cancel() @@ -99,11 +99,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro err = s.syncMesh(mesh, ctx, c) if err != nil { - return s.handleErr(meshId, pubKey.String(), err) + s.handleErr(meshId, pubKey.String(), err) } logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) - return nil + return err } func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error { diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index 35c1c19..54d073a 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -14,5 +14,5 @@ func syncFunction(syncer Syncer) lib.TimerFunc { } func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer { - return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate) + return lib.NewTimer(syncFunction(syncer), s.Conf.SyncTime) } diff --git a/pkg/timers/timers.go b/pkg/timers/timers.go index e26a644..3522775 100644 --- a/pkg/timers/timers.go +++ b/pkg/timers/timers.go @@ -11,5 +11,5 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { logging.Log.WriteInfof("Updated Timestamp") return ctrlServer.MeshManager.UpdateTimeStamp() } - return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) + return *lib.NewTimer(timerFunc, ctrlServer.Conf.HeartBeat) }