From 36e82dba476bb468375b842585543e06d0afaf2e Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Sun, 31 Dec 2023 14:25:06 +0000 Subject: [PATCH] 72-pull-rate-in-configuration - Refactored pull rate into the configuration - code freeze so no more code changes --- cmd/wg-mesh/main.go | 139 +++++++++++++++++++++------------------ pkg/api/apiserver.go | 22 +++---- pkg/conf/conf.go | 2 +- pkg/dns/dns.go | 7 +- pkg/ipc/ipc.go | 103 +++++++++++++++++++++++++---- pkg/mesh/manager.go | 88 ++++++++++++------------- pkg/mesh/manager_test.go | 95 ++++++++++++++++++-------- pkg/mesh/stub_types.go | 27 ++++---- pkg/robin/requester.go | 26 ++++---- 9 files changed, 318 insertions(+), 191 deletions(-) diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 2894326..ab759d5 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -22,25 +22,22 @@ type CreateMeshParams struct { AdvertiseDefault bool } -func createMesh(params *CreateMeshParams) string { +func createMesh(client *ipc.ClientIpc, args *ipc.NewMeshArgs) { var reply string - newMeshParams := ipc.NewMeshArgs{ - WgArgs: params.WgArgs, - } - - err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply) + err := client.CreateMesh(args, &reply) if err != nil { - return err.Error() + fmt.Println(err.Error()) + return } - return reply + fmt.Println(reply) } -func listMeshes(client *ipcRpc.Client) { +func listMeshes(client *ipc.ClientIpc) { reply := new(ipc.ListMeshReply) - err := client.Call("IpcHandler.ListMeshes", "", &reply) + err := client.ListMeshes(reply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -52,38 +49,22 @@ func listMeshes(client *ipcRpc.Client) { } } -type JoinMeshParams struct { - Client *ipcRpc.Client - MeshId string - IpAddress string - Endpoint string - WgArgs ipc.WireGuardArgs - AdvertiseRoutes bool - AdvertiseDefault bool -} - -func joinMesh(params *JoinMeshParams) string { +func joinMesh(client *ipc.ClientIpc, args ipc.JoinMeshArgs) { var reply string - args := ipc.JoinMeshArgs{ - MeshId: params.MeshId, - IpAdress: params.IpAddress, - WgArgs: params.WgArgs, - } - - err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply) + err := client.JoinMesh(args, &reply) if err != nil { - return err.Error() + fmt.Println(err.Error()) } - return reply + fmt.Println(reply) } -func leaveMesh(client *ipcRpc.Client, meshId string) { +func leaveMesh(client *ipc.ClientIpc, meshId string) { var reply string - err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply) + err := client.LeaveMesh(meshId, &reply) if err != nil { fmt.Println(err.Error()) @@ -93,10 +74,10 @@ func leaveMesh(client *ipcRpc.Client, meshId string) { fmt.Println(reply) } -func getGraph(client *ipcRpc.Client) { +func getGraph(client *ipc.ClientIpc) { listMeshesReply := new(ipc.ListMeshReply) - err := client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) + err := client.ListMeshes(listMeshesReply) if err != nil { fmt.Println(err.Error()) @@ -108,7 +89,7 @@ func getGraph(client *ipcRpc.Client) { for _, meshId := range listMeshesReply.Meshes { var meshReply ipc.GetMeshReply - err := client.Call("IpcHandler.GetMesh", &meshId, &meshReply) + err := client.GetMesh(meshId, &meshReply) if err != nil { fmt.Println(err.Error()) @@ -129,10 +110,15 @@ func getGraph(client *ipcRpc.Client) { fmt.Println(dot) } -func queryMesh(client *ipcRpc.Client, meshId, query string) { +func queryMesh(client *ipc.ClientIpc, meshId, query string) { var reply string - err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply) + args := ipc.QueryMesh{ + MeshId: meshId, + Query: query, + } + + err := client.Query(args, &reply) if err != nil { fmt.Println(err.Error()) @@ -142,11 +128,13 @@ func queryMesh(client *ipcRpc.Client, meshId, query string) { fmt.Println(reply) } -// putDescription: puts updates the description about the node to the meshes -func putDescription(client *ipcRpc.Client, description string) { +func putDescription(client *ipc.ClientIpc, meshId, description string) { var reply string - err := client.Call("IpcHandler.PutDescription", &description, &reply) + err := client.PutDescription(ipc.PutDescriptionArgs{ + MeshId: meshId, + Description: description, + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -157,10 +145,13 @@ func putDescription(client *ipcRpc.Client, description string) { } // putAlias: puts an alias for the node -func putAlias(client *ipcRpc.Client, alias string) { +func putAlias(client *ipc.ClientIpc, meshid, alias string) { var reply string - err := client.Call("IpcHandler.PutAlias", &alias, &reply) + err := client.PutAlias(ipc.PutAliasArgs{ + MeshId: meshid, + Alias: alias, + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -170,15 +161,14 @@ func putAlias(client *ipcRpc.Client, alias string) { fmt.Println(reply) } -func setService(client *ipcRpc.Client, service, value string) { +func setService(client *ipc.ClientIpc, meshId, service, value string) { var reply string - serviceArgs := &ipc.PutServiceArgs{ + err := client.PutService(ipc.PutServiceArgs{ + MeshId: meshId, Service: service, Value: value, - } - - err := client.Call("IpcHandler.PutService", serviceArgs, &reply) + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -188,10 +178,13 @@ func setService(client *ipcRpc.Client, service, value string) { fmt.Println(reply) } -func deleteService(client *ipcRpc.Client, service string) { +func deleteService(client *ipc.ClientIpc, meshId, service string) { var reply string - err := client.Call("IpcHandler.PutService", &service, &reply) + err := client.DeleteService(ipc.DeleteServiceArgs{ + MeshId: meshId, + Service: service, + }, &reply) if err != nil { fmt.Println(err.Error()) @@ -300,6 +293,16 @@ func main() { Help: "Description of the node in the mesh", }) + var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + + var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{ Required: true, Help: "Alias of the node to set can be used in DNS to lookup an IP address", @@ -314,11 +317,21 @@ func main() { Help: "Value of the service to advertise in the mesh network", }) + var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{ Required: true, Help: "Key of the service to remove", }) + var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{ + Required: true, + Help: "MeshID of the mesh network to join", + }) + err := parser.Parse(os.Args) if err != nil { @@ -326,16 +339,13 @@ func main() { return } - client, err := ipcRpc.DialHTTP("unix", SockAddr) + client, err := ipc.NewClientIpc() if err != nil { - fmt.Println(err.Error()) - return + panic(err) } if newMeshCmd.Happened() { - fmt.Println(createMesh(&CreateMeshParams{ - Client: client, - Endpoint: *newMeshEndpoint, + args := &ipc.NewMeshArgs{ WgArgs: ipc.WireGuardArgs{ Endpoint: *newMeshEndpoint, Role: *newMeshRole, @@ -344,7 +354,9 @@ func main() { AdvertiseDefaultRoute: *newMeshAdvertiseDefaults, AdvertiseRoutes: *newMeshAdvertiseRoutes, }, - })) + } + + createMesh(client, args) } if listMeshCmd.Happened() { @@ -352,11 +364,9 @@ func main() { } if joinMeshCmd.Happened() { - fmt.Println(joinMesh(&JoinMeshParams{ - Client: client, + args := ipc.JoinMeshArgs{ IpAddress: *joinMeshIpAddress, MeshId: *joinMeshId, - Endpoint: *joinMeshEndpoint, WgArgs: ipc.WireGuardArgs{ Endpoint: *joinMeshEndpoint, Role: *joinMeshRole, @@ -365,7 +375,8 @@ func main() { AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults, AdvertiseRoutes: *joinMeshAdvertiseRoutes, }, - })) + } + joinMesh(client, args) } if getGraphCmd.Happened() { @@ -381,18 +392,18 @@ func main() { } if putDescriptionCmd.Happened() { - putDescription(client, *description) + putDescription(client, *descriptionMeshId, *description) } if putAliasCmd.Happened() { - putAlias(client, *alias) + putAlias(client, *aliasMeshId, *alias) } if setServiceCmd.Happened() { - setService(client, *serviceKey, *serviceValue) + setService(client, *serviceMeshId, *serviceKey, *serviceValue) } if deleteServiceCmd.Happened() { - deleteService(client, *deleteServiceKey) + deleteService(client, *deleteServiceMeshid, *deleteServiceKey) } } diff --git a/pkg/api/apiserver.go b/pkg/api/apiserver.go index d11fc56..46f665d 100644 --- a/pkg/api/apiserver.go +++ b/pkg/api/apiserver.go @@ -4,8 +4,6 @@ import ( "fmt" "net/http" - ipcRpc "net/rpc" - "github.com/gin-gonic/gin" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" @@ -13,8 +11,6 @@ import ( "github.com/tim-beatham/wgmesh/pkg/what8words" ) -const SockAddr = "/tmp/wgmesh_ipc.sock" - type ApiServer interface { GetMeshes(c *gin.Context) Run(addr string) error @@ -22,7 +18,7 @@ type ApiServer interface { type SmegServer struct { router *gin.Engine - client *ipcRpc.Client + client *ipc.ClientIpc words *what8words.What8Words } @@ -106,7 +102,7 @@ func (s *SmegServer) CreateMesh(c *gin.Context) { var reply string - err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply) + err := s.client.CreateMesh(&ipcRequest, &reply) if err != nil { c.JSON(http.StatusBadRequest, &gin.H{ @@ -132,8 +128,8 @@ func (s *SmegServer) JoinMesh(c *gin.Context) { } ipcRequest := ipc.JoinMeshArgs{ - MeshId: joinMesh.MeshId, - IpAdress: joinMesh.Bootstrap, + MeshId: joinMesh.MeshId, + IpAddress: joinMesh.Bootstrap, WgArgs: ipc.WireGuardArgs{ WgPort: joinMesh.WgPort, }, @@ -141,7 +137,7 @@ func (s *SmegServer) JoinMesh(c *gin.Context) { var reply string - err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply) + err := s.client.JoinMesh(ipcRequest, &reply) if err != nil { c.JSON(http.StatusBadRequest, &gin.H{ @@ -164,7 +160,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) { getMeshReply := new(ipc.GetMeshReply) - err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply) + err := s.client.GetMesh(meshid, getMeshReply) if err != nil { c.JSON(http.StatusNotFound, @@ -182,7 +178,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) { func (s *SmegServer) GetMeshes(c *gin.Context) { listMeshesReply := new(ipc.ListMeshReply) - err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) + err := s.client.ListMeshes(listMeshesReply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -195,7 +191,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) { for _, mesh := range listMeshesReply.Meshes { getMeshReply := new(ipc.GetMeshReply) - err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply) + err := s.client.GetMesh(mesh, getMeshReply) if err != nil { logging.Log.WriteErrorf(err.Error()) @@ -215,7 +211,7 @@ func (s *SmegServer) Run(addr string) error { } func NewSmegServer(conf ApiServerConf) (ApiServer, error) { - client, err := ipcRpc.DialHTTP("unix", SockAddr) + client, err := ipc.NewClientIpc() if err != nil { return nil, err diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 3276a78..e0bf283 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -79,7 +79,7 @@ type DaemonConfiguration struct { // SyncTime specifies how long the minimum time should be between synchronisation SyncTime int `yaml:"syncTime" validate:"required,gte=1"` // PullTime specifies the interval between checking for configuration changes - PullTime int `yaml:"pullTime" validate:"required,gte=0"` + PullTime int `yaml:"pullTime" validate:"gte=0"` // HeartBeat: number of seconds before the leader of the mesh sends an update to // send to every member in the mesh HeartBeat int `yaml:"heartBeatTime" validate:"required,gte=1"` diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 8847ee0..2ebfb5a 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "net" - "net/rpc" "github.com/miekg/dns" "github.com/tim-beatham/wgmesh/pkg/ipc" @@ -18,7 +17,7 @@ const SockAddr = "/tmp/wgmesh_ipc.sock" const MeshRegularExpression = `(?P.+)\.(?P.+)\.smeg\.` type DNSHandler struct { - client *rpc.Client + client *ipc.ClientIpc server *dns.Server } @@ -27,7 +26,7 @@ type DNSHandler struct { func (d *DNSHandler) queryMesh(meshId, alias string) net.IP { var reply string - err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{ + err := d.client.Query(ipc.QueryMesh{ MeshId: meshId, Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias), }, &reply) @@ -97,7 +96,7 @@ func (h *DNSHandler) Close() error { } func NewDns(udpPort int) (*DNSHandler, error) { - client, err := rpc.DialHTTP("unix", SockAddr) + client, err := ipc.NewClientIpc() if err != nil { return nil, err diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index b8896ee..895af93 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -5,11 +5,27 @@ import ( "net" "net/http" "net/rpc" + ipcRpc "net/rpc" "os" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ) +const SockAddr = "/tmp/wgmesh_sock" + +type MeshIpc interface { + CreateMesh(args *NewMeshArgs, reply *string) error + ListMeshes(name string, reply *ListMeshReply) error + JoinMesh(args *JoinMeshArgs, reply *string) error + LeaveMesh(meshId string, reply *string) error + GetMesh(meshId string, reply *GetMeshReply) error + Query(query QueryMesh, reply *string) error + PutDescription(args PutDescriptionArgs, reply *string) error + PutAlias(args PutAliasArgs, reply *string) error + PutService(args PutServiceArgs, reply *string) error + DeleteService(args DeleteServiceArgs, reply *string) error +} + // WireGuardArgs are provided args specific to WireGuard type WireGuardArgs struct { // WgPort is the WireGuard port to expose @@ -39,7 +55,7 @@ type JoinMeshArgs struct { // MeshId is the ID of the mesh to join MeshId string // IpAddress is a routable IP in another mesh - IpAdress string + IpAddress string // WgArgs is the WireGuard parameters to use. WgArgs WireGuardArgs } @@ -47,6 +63,22 @@ type JoinMeshArgs struct { type PutServiceArgs struct { Service string Value string + MeshId string +} + +type DeleteServiceArgs struct { + Service string + MeshId string +} + +type PutAliasArgs struct { + Alias string + MeshId string +} + +type PutDescriptionArgs struct { + Description string + MeshId string } type GetMeshReply struct { @@ -62,20 +94,65 @@ type QueryMesh struct { Query string } -type MeshIpc interface { - CreateMesh(args *NewMeshArgs, reply *string) error - ListMeshes(name string, reply *ListMeshReply) error - JoinMesh(args JoinMeshArgs, reply *string) error - LeaveMesh(meshId string, reply *string) error - GetMesh(meshId string, reply *GetMeshReply) error - Query(query QueryMesh, reply *string) error - PutDescription(description string, reply *string) error - PutAlias(alias string, reply *string) error - PutService(args PutServiceArgs, reply *string) error - DeleteService(service string, reply *string) error +type ClientIpc struct { + client *ipcRpc.Client } -const SockAddr = "/tmp/wgmesh_ipc.sock" +func NewClientIpc() (*ClientIpc, error) { + client, err := ipcRpc.DialHTTP("unix", SockAddr) + + if err != nil { + return nil, err + } + + return &ClientIpc{ + client: client, + }, nil +} + +func (c *ClientIpc) CreateMesh(args *NewMeshArgs, reply *string) error { + return c.client.Call("IpcHandler.CreateMesh", args, reply) +} + +func (c *ClientIpc) ListMeshes(reply *ListMeshReply) error { + return c.client.Call("IpcHandler.ListMeshes", "", reply) +} + +func (c *ClientIpc) JoinMesh(args JoinMeshArgs, reply *string) error { + return c.client.Call("IpcHandler.JoinMesh", &args, reply) +} + +func (c *ClientIpc) LeaveMesh(meshId string, reply *string) error { + return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply) +} + +func (c *ClientIpc) GetMesh(meshId string, reply *GetMeshReply) error { + return c.client.Call("IpcHandler.GetMesh", &meshId, reply) +} + +func (c *ClientIpc) Query(query QueryMesh, reply *string) error { + return c.client.Call("IpcHandler.Query", &query, reply) +} + +func (c *ClientIpc) PutDescription(args PutDescriptionArgs, reply *string) error { + return c.client.Call("IpcHandler.PutDescription", &args, reply) +} + +func (c *ClientIpc) PutAlias(args PutAliasArgs, reply *string) error { + return c.client.Call("IpcHandler.PutAlias", &args, reply) +} + +func (c *ClientIpc) PutService(args PutServiceArgs, reply *string) error { + return c.client.Call("IpcHandler.PutService", &args, reply) +} + +func (c *ClientIpc) DeleteService(args DeleteServiceArgs, reply *string) error { + return c.client.Call("IpcHandler.DeleteService", &args, reply) +} + +func (c *ClientIpc) Close() error { + return c.Close() +} func RunIpcHandler(server MeshIpc) error { if err := os.RemoveAll(SockAddr); err != nil { diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index ee82101..efd60c3 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -24,10 +24,10 @@ type MeshManager interface { LeaveMesh(meshId string) error GetSelf(meshId string) (MeshNode, error) ApplyConfig() error - SetDescription(description string) error - SetAlias(alias string) error - SetService(service string, value string) error - RemoveService(service string) error + SetDescription(meshId, description string) error + SetAlias(meshId, alias string) error + SetService(meshId, service, value string) error + RemoveService(meshId, service string) error UpdateTimeStamp() error GetClient() *wgctrl.Client GetMeshes() map[string]MeshProvider @@ -61,29 +61,33 @@ func (m *MeshManagerImpl) GetRouteManager() RouteManager { } // RemoveService implements MeshManager. -func (m *MeshManagerImpl) RemoveService(service string) error { - for _, mesh := range m.Meshes { - err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service) +func (m *MeshManagerImpl) RemoveService(meshId, service string) error { + mesh := m.GetMesh(meshId) - if err != nil { - return err - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + if !mesh.NodeExists(m.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.RemoveService(m.HostParameters.GetPublicKey(), service) } // SetService implements MeshManager. -func (m *MeshManagerImpl) SetService(service string, value string) error { - for _, mesh := range m.Meshes { - err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value) +func (m *MeshManagerImpl) SetService(meshId, service, value string) error { + mesh := m.GetMesh(meshId) - if err != nil { - return err - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + if !mesh.NodeExists(m.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.AddService(m.HostParameters.GetPublicKey(), service, value) } func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode { @@ -352,7 +356,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error { } s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...) - return err } @@ -377,43 +380,36 @@ func (s *MeshManagerImpl) ApplyConfig() error { return nil } - err := s.configApplyer.ApplyConfig() - - if err != nil { - return err - } - - return nil + return s.configApplyer.ApplyConfig() } -func (s *MeshManagerImpl) SetDescription(description string) error { - meshes := s.GetMeshes() - for _, mesh := range meshes { - if mesh.NodeExists(s.HostParameters.GetPublicKey()) { - err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description) +func (s *MeshManagerImpl) SetDescription(meshId, description string) error { + mesh := s.GetMesh(meshId) - if err != nil { - return err - } - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + if !mesh.NodeExists(s.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.SetDescription(s.HostParameters.GetPublicKey(), description) } // SetAlias implements MeshManager. -func (s *MeshManagerImpl) SetAlias(alias string) error { - meshes := s.GetMeshes() - for _, mesh := range meshes { - if mesh.NodeExists(s.HostParameters.GetPublicKey()) { - err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) +func (s *MeshManagerImpl) SetAlias(meshId, alias string) error { + mesh := s.GetMesh(meshId) - if err != nil { - return err - } - } + if mesh == nil { + return fmt.Errorf("mesh %s does not exist", meshId) } - return nil + + if !mesh.NodeExists(s.HostParameters.GetPublicKey()) { + return fmt.Errorf("node %s does not exist in the mesh", meshId) + } + + return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) } // UpdateTimeStamp updates the timestamp of this node in all meshes diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index 46af9c2..d621426 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -213,7 +213,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) { } } -func TestSetAlias(t *testing.T) { +func TestSetAliasUpdatesAliasOfNode(t *testing.T) { manager := getMeshManager() alias := "Firpo" @@ -221,14 +221,13 @@ func TestSetAlias(t *testing.T) { Port: 5000, Conf: &conf.WgConfiguration{}, }) - manager.AddSelf(&AddSelfParams{ MeshId: meshId, WgPort: 5000, Endpoint: "abc.com:8080", }) - err := manager.SetAlias(alias) + err := manager.SetAlias(meshId, alias) if err != nil { t.Fatalf(`failed to set the alias`) @@ -245,7 +244,7 @@ func TestSetAlias(t *testing.T) { } } -func TestSetDescription(t *testing.T) { +func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) { manager := getMeshManager() description := "wooooo" @@ -254,23 +253,13 @@ func TestSetDescription(t *testing.T) { Conf: &conf.WgConfiguration{}, }) - meshId2, _ := manager.CreateMesh(&CreateMeshParams{ - Port: 5001, - Conf: &conf.WgConfiguration{}, - }) - manager.AddSelf(&AddSelfParams{ MeshId: meshId1, WgPort: 5000, Endpoint: "abc.com:8080", }) - manager.AddSelf(&AddSelfParams{ - MeshId: meshId2, - WgPort: 5000, - Endpoint: "abc.com:8080", - }) - err := manager.SetDescription(description) + err := manager.SetDescription(meshId1, description) if err != nil { t.Fatalf(`failed to set the descriptions`) @@ -285,18 +274,7 @@ func TestSetDescription(t *testing.T) { if description != self1.GetDescription() { t.Fatalf(`description should be %s was %s`, description, self1.GetDescription()) } - - self2, err := manager.GetSelf(meshId2) - - if err != nil { - t.Fatalf(`failed to set the description`) - } - - if description != self2.GetDescription() { - t.Fatalf(`description should be %s was %s`, description, self2.GetDescription()) - } } - func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { manager := getMeshManager() @@ -327,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { t.Fatalf(`failed to update the timestamp`) } } + +func TestAddServiceAddsServiceToTheMesh(t *testing.T) { + manager := getMeshManager() + + meshId1, _ := manager.CreateMesh(&CreateMeshParams{ + Port: 5000, + Conf: &conf.WgConfiguration{}, + }) + manager.AddSelf(&AddSelfParams{ + MeshId: meshId1, + WgPort: 5000, + Endpoint: "abc.com:8080", + }) + + serviceName := "hello" + manager.SetService(meshId1, serviceName, "dave") + + self, err := manager.GetSelf(meshId1) + + if err != nil { + t.Fatalf(`error thrown %s:`, err.Error()) + } + + if _, ok := self.GetServices()[serviceName]; !ok { + t.Fatalf(`service not added`) + } +} + +func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) { + manager := getMeshManager() + + meshId1, _ := manager.CreateMesh(&CreateMeshParams{ + Port: 5000, + Conf: &conf.WgConfiguration{}, + }) + manager.AddSelf(&AddSelfParams{ + MeshId: meshId1, + WgPort: 5000, + Endpoint: "abc.com:8080", + }) + + serviceName := "hello" + manager.SetService(meshId1, serviceName, "dave") + + self, err := manager.GetSelf(meshId1) + + if err != nil { + t.Fatalf(`error thrown %s:`, err.Error()) + } + + if _, ok := self.GetServices()[serviceName]; !ok { + t.Fatalf(`service not added`) + } + + manager.RemoveService(meshId1, serviceName) + self, err = manager.GetSelf(meshId1) + + if err != nil { + t.Fatalf(`error thrown %s:`, err.Error()) + } + + if _, ok := self.GetServices()[serviceName]; ok { + t.Fatalf(`service still exists`) + } +} diff --git a/pkg/mesh/stub_types.go b/pkg/mesh/stub_types.go index 2891c6f..a49f55a 100644 --- a/pkg/mesh/stub_types.go +++ b/pkg/mesh/stub_types.go @@ -30,8 +30,8 @@ func (*MeshNodeStub) GetType() conf.NodeType { } // GetServices implements MeshNode. -func (*MeshNodeStub) GetServices() map[string]string { - return make(map[string]string) +func (m *MeshNodeStub) GetServices() map[string]string { + return m.services } // GetAlias implements MeshNode. @@ -249,6 +249,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode { routes: make([]Route, 0), identifier: "abc", description: "A Mesh Node Stub", + services: make(map[string]string), } } @@ -271,32 +272,32 @@ type MeshManagerStub struct { // GetRouteManager implements MeshManager. func (*MeshManagerStub) GetRouteManager() RouteManager { - panic("unimplemented") + return nil } // GetNode implements MeshManager. -func (*MeshManagerStub) GetNode(string, string) MeshNode { - panic("unimplemented") +func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode { + return nil } // RemoveService implements MeshManager. -func (*MeshManagerStub) RemoveService(service string) error { - panic("unimplemented") +func (*MeshManagerStub) RemoveService(meshId, service string) error { + return nil } // SetService implements MeshManager. -func (*MeshManagerStub) SetService(service string, value string) error { - panic("unimplemented") +func (*MeshManagerStub) SetService(meshId, service, value string) error { + return nil } // SetAlias implements MeshManager. -func (*MeshManagerStub) SetAlias(alias string) error { - panic("unimplemented") +func (*MeshManagerStub) SetAlias(meshId, alias string) error { + return nil } // Close implements MeshManager. func (*MeshManagerStub) Close() error { - panic("unimplemented") + return nil } // Prune implements MeshManager. @@ -348,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error { return nil } -func (m *MeshManagerStub) SetDescription(description string) error { +func (m *MeshManagerStub) SetDescription(meshId, description string) error { return nil } diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 5d86f80..795c649 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -79,14 +79,14 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error { return nil } -func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { +func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error { overrideConf := getOverrideConfiguration(&args.WgArgs) if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil { return fmt.Errorf("user is already apart of the mesh") } - peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) + peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress) if err != nil { return err @@ -192,30 +192,34 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error { return nil } -func (n *IpcHandler) PutDescription(description string, reply *string) error { - err := n.Server.GetMeshManager().SetDescription(description) +func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error { + err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description) if err != nil { return err } - *reply = fmt.Sprintf("Set description to %s", description) + *reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId) return nil } -func (n *IpcHandler) PutAlias(alias string, reply *string) error { - err := n.Server.GetMeshManager().SetAlias(alias) +func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error { + if args.Alias == "" { + return fmt.Errorf("alias not provided") + } + + err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias) if err != nil { return err } - *reply = fmt.Sprintf("Set alias to %s", alias) + *reply = fmt.Sprintf("Set alias to %s", args.Alias) return nil } func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error { - err := n.Server.GetMeshManager().SetService(service.Service, service.Value) + err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value) if err != nil { return err @@ -225,8 +229,8 @@ func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error return nil } -func (n *IpcHandler) DeleteService(service string, reply *string) error { - err := n.Server.GetMeshManager().RemoveService(service) +func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error { + err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service) if err != nil { return err