From bb07d35dcbc13f76c4b5e054c28e0fe9d0630a8d Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Sun, 5 Nov 2023 12:08:20 +0000 Subject: [PATCH] Unit testing the automerge library and lib functions --- pkg/automerge/automerge.go | 12 +- pkg/automerge/automerge_test.go | 366 ++++++++++++++++++++++++++++++++ pkg/lib/conv.go | 4 +- pkg/lib/conv_test.go | 144 +++++++++++++ pkg/lib/random_test.go | 46 ++++ pkg/timestamp/timestamp.go | 5 +- 6 files changed, 563 insertions(+), 14 deletions(-) create mode 100644 pkg/automerge/automerge_test.go create mode 100644 pkg/lib/conv_test.go create mode 100644 pkg/lib/random_test.go diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index a349234..da7774e 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -85,16 +85,6 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro return &manager, nil } -func (c *CrdtMeshManager) removeNode(endpoint string) error { - err := c.doc.Path("nodes").Map().Delete(endpoint) - - if err != nil { - return err - } - - return nil -} - // GetNode: returns a mesh node crdt.Close releases resources used by a Client. func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { node, err := m.doc.Path("nodes").Map().Get(endpoint) @@ -176,7 +166,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro } if node.Kind() != automerge.KindMap { - return errors.New(fmt.Sprintf("%s does not exist", nodeId)) + return fmt.Errorf("%s does not exist", nodeId) } err = node.Map().Set("description", description) diff --git a/pkg/automerge/automerge_test.go b/pkg/automerge/automerge_test.go new file mode 100644 index 0000000..b25355a --- /dev/null +++ b/pkg/automerge/automerge_test.go @@ -0,0 +1,366 @@ +package crdt + +import ( + "slices" + "strings" + "testing" + "time" + + "github.com/tim-beatham/wgmesh/pkg/conf" + "github.com/tim-beatham/wgmesh/pkg/lib" + "github.com/tim-beatham/wgmesh/pkg/mesh" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type TestParams struct { + manager *CrdtMeshManager +} + +func setUpTests() *TestParams { + manager, _ := NewCrdtNodeManager(&NewCrdtNodeMangerParams{ + MeshId: "timsmesh123", + DevName: "wg0", + Port: 5000, + Client: nil, + Conf: conf.WgMeshConfiguration{}, + }) + + return &TestParams{ + manager: manager, + } +} + +func getTestNode() mesh.MeshNode { + return &MeshNodeCrdt{ + HostEndpoint: "public-endpoint:8080", + WgEndpoint: "public-endpoint:21906", + WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128", + PublicKey: "AAAAAAAAAAAA", + Timestamp: time.Now().Unix(), + Description: "A node that we are adding", + } +} + +func getTestNode2() mesh.MeshNode { + return &MeshNodeCrdt{ + HostEndpoint: "public-endpoint:8081", + WgEndpoint: "public-endpoint:21907", + WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128", + PublicKey: "BBBBBBBBB", + Timestamp: time.Now().Unix(), + Description: "A node that we are adding", + } +} + +func TestAddNodeNodeExists(t *testing.T) { + testParams := setUpTests() + testParams.manager.AddNode(getTestNode()) + + node, err := testParams.manager.GetNode("public-endpoint:8080") + + if err != nil { + t.Error(err) + } + + if node == nil { + t.Fatalf(`node not added to the mesh when it should be`) + } +} + +func TestAddNodeAddRoute(t *testing.T) { + testParams := setUpTests() + testNode := getTestNode() + testParams.manager.AddNode(testNode) + testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48") + + updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint()) + + if err != nil { + t.Error(err) + } + + if updatedNode == nil { + t.Fatalf(`Node does not exist in the mesh`) + } + + routes := updatedNode.GetRoutes() + + if !slices.Contains(routes, "fdfd:1c64:1d00::/48") { + t.Fatal("Route node added") + } + + if len(routes) != 1 { + t.Fatal(`Route length mismatch`) + } +} + +func TestGetMeshIdReturnsTheMeshId(t *testing.T) { + testParams := setUpTests() + + if len(testParams.manager.GetMeshId()) == 0 { + t.Fatal(`Meshid is less than 0`) + } +} + +// Add 2 nodes to the mesh and then get the mesh.s +// It should return the 2 nodes that have been added to the mesh +func TestAdd2NodesGetMesh(t *testing.T) { + testParams := setUpTests() + node1 := getTestNode() + node2 := getTestNode2() + + testParams.manager.AddNode(node1) + testParams.manager.AddNode(node2) + + mesh, err := testParams.manager.GetMesh() + + if err != nil { + t.Error(err) + } + + nodes := mesh.GetNodes() + + if len(nodes) != 2 { + t.Fatalf(`Mismatch in node slice`) + } + + for _, node := range nodes { + if node.GetHostEndpoint() != node1.GetHostEndpoint() && node.GetHostEndpoint() != node2.GetHostEndpoint() { + t.Fatalf(`Node should not exist`) + } + } +} + +func TestSaveMeshReturnsMeshBytes(t *testing.T) { + testParams := setUpTests() + node1 := getTestNode() + + testParams.manager.AddNode(node1) + + bytes := testParams.manager.Save() + + if len(bytes) <= 0 { + t.Fatalf(`bytes in the mesh is less than 0`) + } +} + +func TestSaveMeshThenLoad(t *testing.T) { + testParams := setUpTests() + testParams2 := setUpTests() + + node1 := getTestNode() + testParams.manager.AddNode(node1) + + bytes := testParams.manager.Save() + + err := testParams2.manager.Load(bytes) + + if err != nil { + t.Error(err) + } + + if len(bytes) <= 0 { + t.Fatalf(`bytes in the mesh is less than 0`) + } + + mesh2, err := testParams2.manager.GetMesh() + + if err != nil { + t.Error(err) + } + + nodes := mesh2.GetNodes() + + if lib.MapValues(nodes)[0].GetHostEndpoint() != node1.GetHostEndpoint() { + t.Fatalf(`Node should be in the list of nodes`) + } +} + +func TestLengthNoNodes(t *testing.T) { + testParams := setUpTests() + + if testParams.manager.Length() != 0 { + t.Fatalf(`Number of nodes should be 0`) + } +} + +func TestLength1Node(t *testing.T) { + testParams := setUpTests() + node := getTestNode() + testParams.manager.AddNode(node) + + if testParams.manager.Length() != 1 { + t.Fatalf(`Number of nodes should be 1`) + } +} + +func TestLengthMultipleNodes(t *testing.T) { + testParams := setUpTests() + node := getTestNode() + node1 := getTestNode() + + testParams.manager.AddNode(node) + testParams.manager.AddNode(node1) + + if testParams.manager.Length() != 2 { + t.Fatalf(`Number of nodes should be 2`) + } +} + +func TestHasChangesNoChanges(t *testing.T) { + testParams := setUpTests() + + if testParams.manager.HasChanges() { + t.Fatalf(`Should not have changes just created document`) + } +} + +func TestHasChangesChanges(t *testing.T) { + testParams := setUpTests() + node := getTestNode() + + testParams.manager.AddNode(node) + + if !testParams.manager.HasChanges() { + t.Fatalf(`Should have changes just added node`) + } +} + +func TestHasChangesSavedChanges(t *testing.T) { + testParams := setUpTests() + node := getTestNode() + + testParams.manager.AddNode(node) + + testParams.manager.SaveChanges() + + if testParams.manager.HasChanges() { + t.Fatalf(`Should not have changes just saved document`) + } +} + +func TestUpdateTimeStampNodeDoesNotExist(t *testing.T) { + testParams := setUpTests() + err := testParams.manager.UpdateTimeStamp("AAAAAA") + + if err == nil { + t.Fatalf(`Error should have returned`) + } +} + +func TestUpdateTimeStampNodeExists(t *testing.T) { + testParams := setUpTests() + node := getTestNode() + + testParams.manager.AddNode(node) + err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint()) + + if err != nil { + t.Error(err) + } +} + +func TestSetDescriptionNodeDoesNotExist(t *testing.T) { + testParams := setUpTests() + err := testParams.manager.SetDescription("AAAAA", "Bob 123") + + if err == nil { + t.Fatalf(`Error should have returned`) + } +} + +func TestSetDescriptionNodeExists(t *testing.T) { + testParams := setUpTests() + node := getTestNode() + err := testParams.manager.SetDescription(node.GetHostEndpoint(), "Bob 123") + + if err == nil { + t.Fatalf(`Error should have returned`) + } +} + +func TestAddRoutesNodeDoesNotExist(t *testing.T) { + testParams := setUpTests() + + err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48") + + if err == nil { + t.Error(err) + } +} + +func TestCompareComparesByPublicKey(t *testing.T) { + node := getTestNode().(*MeshNodeCrdt) + node2 := getTestNode2().(*MeshNodeCrdt) + + if node.Compare(node2) != -1 { + t.Fatalf(`node is alphabetically before node2`) + } + + if node2.Compare(node) != 1 { + t.Fatalf(`node is alphabetical;y before node2`) + } + + if node.Compare(node) != 0 { + t.Fatalf(`node is equal to node`) + } +} + +func TestGetHostEndpoint(t *testing.T) { + node := getTestNode() + + if (node.(*MeshNodeCrdt)).HostEndpoint != node.GetHostEndpoint() { + t.Fatalf(`get hostendpoint should get the host endpoint`) + } +} + +func TestGetPublicKey(t *testing.T) { + key1, _ := wgtypes.GenerateKey() + + node := getTestNode() + node.(*MeshNodeCrdt).PublicKey = key1.String() + + pubKey, err := node.GetPublicKey() + + if err != nil { + t.Error(err) + } + + if pubKey.String() != key1.String() { + t.Fatalf(`Expected %s got %s`, key1.String(), pubKey.String()) + } +} + +func TestGetWgEndpoint(t *testing.T) { + node := getTestNode() + + if node.(*MeshNodeCrdt).WgEndpoint != node.GetWgEndpoint() { + t.Fatal(`Did not return the correct wgEndpoint`) + } +} + +func TestGetWgHost(t *testing.T) { + node := getTestNode() + + ip := node.GetWgHost() + + if node.(*MeshNodeCrdt).WgHost != ip.String() { + t.Fatal(`Did not parse WgHost correctly`) + } +} + +func TestGetTimeStamp(t *testing.T) { + node := getTestNode() + + if node.(*MeshNodeCrdt).Timestamp != node.GetTimeStamp() { + t.Fatal(`Did not return return the correct timestamp`) + } +} + +func TestGetIdentifierDoesNotContainPrefix(t *testing.T) { + node := getTestNode() + + if strings.Contains(node.GetIdentifier(), "/128") { + t.Fatal(`Identifier should not contain prefix`) + } +} diff --git a/pkg/lib/conv.go b/pkg/lib/conv.go index 65b29bb..dc73545 100644 --- a/pkg/lib/conv.go +++ b/pkg/lib/conv.go @@ -30,7 +30,7 @@ func MapKeys[K comparable, V any](m map[K]V) []K { values := make([]K, len(m)) i := 0 - for k, _ := range m { + for k := range m { values[i] = k i++ } @@ -58,7 +58,7 @@ type filterFunc[V any] func(V) bool func Filter[V any](list []V, f filterFunc[V]) []V { newList := make([]V, 0) - for _, elem := range newList { + for _, elem := range list { if f(elem) { newList = append(newList, elem) } diff --git a/pkg/lib/conv_test.go b/pkg/lib/conv_test.go new file mode 100644 index 0000000..ef73cd7 --- /dev/null +++ b/pkg/lib/conv_test.go @@ -0,0 +1,144 @@ +package lib + +import ( + "slices" + "testing" +) + +func stringToInt(input string) int { + return len(input) +} + +func intDiv(input int) int { + return input / 2 +} + +func TestMapValuesMapsValues(t *testing.T) { + values := []int{1, 4, 11, 92} + var theMap map[string]int = map[string]int{ + "mynameisjeff": values[0], + "tim": values[1], + "bob": values[2], + "derek": values[3], + } + + mapValues := MapValues(theMap) + + for _, elem := range mapValues { + if !slices.Contains(values, elem) { + t.Fatalf(`%d is not an expected value`, elem) + } + } + + if len(mapValues) != len(theMap) { + t.Fatalf(`Expected length %d got %d`, len(theMap), len(mapValues)) + } +} + +func TestMapValuesWithExcludeExcludesValues(t *testing.T) { + values := []int{1, 9, 22} + var theMap map[string]int = map[string]int{ + "mynameisbob": values[0], + "tim": values[1], + "bob": values[2], + } + + exclude := map[string]struct{}{ + "tim": {}, + } + + mapValues := MapValuesWithExclude(theMap, exclude) + + if slices.Contains(mapValues, values[1]) { + t.Fatalf(`Failed to exclude expected value`) + } + + if len(mapValues) != 2 { + t.Fatalf(`Incorrect expected length`) + } + + for _, value := range theMap { + if !slices.Contains(values, value) { + t.Fatalf(`Element does not exist in the list of + expected values`) + } + } +} + +func TestMapKeys(t *testing.T) { + keys := []string{"1", "2", "3"} + + theMap := map[string]int{ + keys[0]: 1, + keys[1]: 2, + keys[2]: 3, + } + + mapKeys := MapKeys(theMap) + + for _, elem := range mapKeys { + if !slices.Contains(keys, elem) { + t.Fatalf(`%s elem is not an expected key`, elem) + } + } + + if len(mapKeys) != len(theMap) { + t.Fatalf(`Missing expected values`) + } +} + +func TestMapValues(t *testing.T) { + array := []string{"mynameisjeff", "tim", "bob", "derek"} + + intArray := Map(array, stringToInt) + + for index, elem := range intArray { + if len(array[index]) != elem { + t.Fatalf(`Have %d want %d`, elem, len(array[index])) + } + } +} + +func TestFilterFilterAll(t *testing.T) { + values := []int{1, 2, 3, 4, 5} + + filterFunc := func(n int) bool { + return false + } + + newValues := Filter(values, filterFunc) + + if len(newValues) != 0 { + t.Fatalf(`Expected value was 0`) + } +} + +func TestFilterFilterNone(t *testing.T) { + values := []int{1, 2, 3, 4, 5} + + filterFunc := func(n int) bool { + return true + } + + newValues := Filter(values, filterFunc) + + if !slices.Equal(values, newValues) { + t.Fatalf(`Expected lists to be the same`) + } +} + +func TestFilterFilterSome(t *testing.T) { + values := []int{1, 2, 3, 4, 5} + + filterFunc := func(n int) bool { + return n < 3 + } + + expected := []int{1, 2} + + actual := Filter(values, filterFunc) + + if !slices.Equal(expected, actual) { + t.Fatalf(`Expected expected and actual to be the same`) + } +} diff --git a/pkg/lib/random_test.go b/pkg/lib/random_test.go new file mode 100644 index 0000000..5c564a6 --- /dev/null +++ b/pkg/lib/random_test.go @@ -0,0 +1,46 @@ +package lib + +import ( + "slices" + "testing" +) + +// Test that a random subset of length 0 produces a zero length +// list +func TestRandomSubsetOfLength0(t *testing.T) { + values := []int{1, 2, 3, 4, 5, 6, 7, 8} + randomValues := RandomSubsetOfLength(values, 0) + + if len(randomValues) != 0 { + t.Fatalf(`Expected length to be 0`) + } +} + +func TestRandomSubsetOfLength1(t *testing.T) { + values := []int{1, 2, 3, 4, 5} + + randomValues := RandomSubsetOfLength(values, 1) + + if len(randomValues) != 1 { + t.Fatalf(`Expected length to be 1`) + } + + if !slices.Contains(values, randomValues[0]) { + t.Fatalf(`Expected length to be 1`) + } +} + +func TestRandomSubsetEntireList(t *testing.T) { + values := []int{1, 2, 3, 4, 5} + randomValues := RandomSubsetOfLength(values, len(values)) + + if len(randomValues) != len(values) { + t.Fatalf(`Expected length to be %d was %d`, len(values), len(randomValues)) + } + + slices.Sort(randomValues) + + if !slices.Equal(values, randomValues) { + t.Fatalf(`Expected slices to be equal`) + } +} diff --git a/pkg/timestamp/timestamp.go b/pkg/timestamp/timestamp.go index 971d138..aa87b97 100644 --- a/pkg/timestamp/timestamp.go +++ b/pkg/timestamp/timestamp.go @@ -39,7 +39,10 @@ func (s *TimeStampSchedulerImpl) Run() error { } func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) TimestampScheduler { - return &TimeStampSchedulerImpl{meshManager: ctrlServer.MeshManager, updateRate: ctrlServer.Conf.KeepAliveRate} + return &TimeStampSchedulerImpl{ + meshManager: ctrlServer.MeshManager, + updateRate: ctrlServer.Conf.KeepAliveRate, + } } func (s *TimeStampSchedulerImpl) Stop() error {