diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index b62db73..11cadca 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -40,7 +40,11 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { crdt.Services = make(map[string]string) crdt.Timestamp = time.Now().Unix() - c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) + err := c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) + + if err != nil { + logging.Log.WriteInfof("error") + } } func (c *CrdtMeshManager) isPeer(nodeId string) bool { @@ -161,7 +165,7 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { node, err := m.doc.Path("nodes").Map().Get(endpoint) if node.Kind() != automerge.KindMap { - return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type") + return nil, fmt.Errorf("getnode: node is not a map") } if err != nil { diff --git a/pkg/automerge/automerge_test.go b/pkg/automerge/automerge_test.go index 4c74e93..66eb1b4 100644 --- a/pkg/automerge/automerge_test.go +++ b/pkg/automerge/automerge_test.go @@ -1,7 +1,7 @@ package automerge import ( - "slices" + "net" "strings" "testing" "time" @@ -22,7 +22,7 @@ func setUpTests() *TestParams { DevName: "wg0", Port: 5000, Client: nil, - Conf: conf.DaemonConfiguration{}, + Conf: &conf.WgConfiguration{}, }) return &TestParams{ @@ -31,22 +31,26 @@ func setUpTests() *TestParams { } func getTestNode() mesh.MeshNode { + pubKey, _ := wgtypes.GeneratePrivateKey() + return &MeshNodeCrdt{ HostEndpoint: "public-endpoint:8080", WgEndpoint: "public-endpoint:21906", WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128", - PublicKey: "AAAAAAAAAAAA", + PublicKey: pubKey.String(), Timestamp: time.Now().Unix(), Description: "A node that we are adding", } } func getTestNode2() mesh.MeshNode { + pubKey, _ := wgtypes.GeneratePrivateKey() + return &MeshNodeCrdt{ HostEndpoint: "public-endpoint:8081", WgEndpoint: "public-endpoint:21907", WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128", - PublicKey: "BBBBBBBBB", + PublicKey: pubKey.String(), Timestamp: time.Now().Unix(), Description: "A node that we are adding", } @@ -54,9 +58,11 @@ func getTestNode2() mesh.MeshNode { func TestAddNodeNodeExists(t *testing.T) { testParams := setUpTests() - testParams.manager.AddNode(getTestNode()) + node := getTestNode() + testParams.manager.AddNode(node) - node, err := testParams.manager.GetNode("public-endpoint:8080") + pubKey, _ := node.GetPublicKey() + node, err := testParams.manager.GetNode(pubKey.String()) if err != nil { t.Error(err) @@ -70,25 +76,28 @@ func TestAddNodeNodeExists(t *testing.T) { func TestAddNodeAddRoute(t *testing.T) { testParams := setUpTests() testNode := getTestNode() - testParams.manager.AddNode(testNode) - testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48") + pubKey, _ := testNode.GetPublicKey() - updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint()) + _, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48") + + testParams.manager.AddNode(testNode) + testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{ + Destination: destination, + HopCount: 0, + Path: make([]string, 0), + }) + updatedNode, err := testParams.manager.GetNode(pubKey.String()) if err != nil { t.Error(err) } if updatedNode == nil { - t.Fatalf(`Node does not exist in the mesh`) + t.Fatalf(`node does not exist in the mesh`) } routes := updatedNode.GetRoutes() - if !slices.Contains(routes, "fd:1c64:1d00::/48") { - t.Fatal("Route node not added") - } - if len(routes) != 1 { t.Fatal(`Route length mismatch`) } @@ -253,7 +262,9 @@ func TestUpdateTimeStampNodeExists(t *testing.T) { node := getTestNode() testParams.manager.AddNode(node) - err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint()) + pubKey, _ := node.GetPublicKey() + + err := testParams.manager.UpdateTimeStamp(pubKey.String()) if err != nil { t.Error(err) @@ -282,7 +293,13 @@ func TestSetDescriptionNodeExists(t *testing.T) { func TestAddRoutesNodeDoesNotExist(t *testing.T) { testParams := setUpTests() - err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48") + _, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48") + + err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{ + Destination: destination, + HopCount: 0, + Path: make([]string, 0), + }) if err == nil { t.Error(err) @@ -293,16 +310,11 @@ func TestCompareComparesByPublicKey(t *testing.T) { node := getTestNode().(*MeshNodeCrdt) node2 := getTestNode2().(*MeshNodeCrdt) - if node.Compare(node2) != -1 { - t.Fatalf(`node is alphabetically before node2`) - } + pubKey1, _ := node.GetPublicKey() + pubKey2, _ := node2.GetPublicKey() - if node2.Compare(node) != 1 { - t.Fatalf(`node is alphabetical;y before node2`) - } - - if node.Compare(node) != 0 { - t.Fatalf(`node is equal to node`) + if node.Compare(node2) != strings.Compare(pubKey1.String(), pubKey2.String()) { + t.Fatalf(`compare failed`) } } diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index 2778e57..7758268 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -28,7 +28,7 @@ type MeshNodeFactory struct { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { hostName := f.getAddress(params) - grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) + grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort) if *params.MeshConfig.Role == conf.CLIENT_ROLE { grpcEndpoint = "-" diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index 6facf82..45c8138 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -1,13 +1,40 @@ package conf -import "testing" +import ( + "testing" +) func getExampleConfiguration() *DaemonConfiguration { + discovery := PUBLIC_IP_DISCOVERY + advertiseRoutes := false + advertiseDefaultRoute := false + endpoint := "abc.com:123" + nodeType := CLIENT_ROLE + keepAliveWg := 0 + return &DaemonConfiguration{ - CertificatePath: "./cert/cert.pem", - PrivateKeyPath: "./cert/key.pem", - CaCertificatePath: "./cert/ca.pems", + CertificatePath: "../../../cert/cert.pem", + PrivateKeyPath: "../../../cert/priv.pem", + CaCertificatePath: "../../../cert/cacert.pem", SkipCertVerification: true, + GrpcPort: 25, + Timeout: 5, + Profile: false, + StubWg: false, + SyncRate: 2, + KeepAliveTime: 2, + ClusterSize: 64, + InterClusterChance: 0.15, + BranchRate: 3, + InfectionCount: 2, + BaseConfiguration: WgConfiguration{ + IPDiscovery: &discovery, + AdvertiseRoutes: &advertiseRoutes, + AdvertiseDefaultRoute: &advertiseDefaultRoute, + Endpoint: &endpoint, + Role: &nodeType, + KeepAliveWg: &keepAliveWg, + }, } } @@ -55,9 +82,141 @@ func TestConfigurationGrpcPortEmpty(t *testing.T) { } } +func TestIPDiscoveryNotSet(t *testing.T) { + conf := getExampleConfiguration() + ipDiscovery := IPDiscovery("djdsjdskd") + conf.BaseConfiguration.IPDiscovery = &ipDiscovery + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestAdvertiseRoutesNotSet(t *testing.T) { + conf := getExampleConfiguration() + conf.BaseConfiguration.AdvertiseRoutes = nil + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestAdvertiseDefaultRouteNotSet(t *testing.T) { + conf := getExampleConfiguration() + conf.BaseConfiguration.AdvertiseDefaultRoute = nil + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestKeepAliveWgNegative(t *testing.T) { + conf := getExampleConfiguration() + keepAliveWg := -1 + conf.BaseConfiguration.KeepAliveWg = &keepAliveWg + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestRoleTypeNotValid(t *testing.T) { + conf := getExampleConfiguration() + role := NodeType("bruhhh") + conf.BaseConfiguration.Role = &role + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestRoleTypeNotSpecified(t *testing.T) { + conf := getExampleConfiguration() + conf.BaseConfiguration.Role = nil + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`invalid role type`) + } +} + +func TestBranchRateZero(t *testing.T) { + conf := getExampleConfiguration() + conf.BranchRate = 0 + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestSyncRateZero(t *testing.T) { + conf := getExampleConfiguration() + conf.SyncRate = 0 + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestKeepAliveTimeZero(t *testing.T) { + conf := getExampleConfiguration() + conf.KeepAliveTime = 0 + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestClusterSizeZero(t *testing.T) { + conf := getExampleConfiguration() + conf.ClusterSize = 0 + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestInterClusterChanceZero(t *testing.T) { + conf := getExampleConfiguration() + conf.InterClusterChance = 0 + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + +func TestInfectionCountOne(t *testing.T) { + conf := getExampleConfiguration() + conf.InfectionCount = 0 + + err := ValidateDaemonConfiguration(conf) + + if err == nil { + t.Fatal(`error should be thrown`) + } +} + func TestValidConfiguration(t *testing.T) { conf := getExampleConfiguration() - err := ValidateDaemonConfiguration(conf) if err != nil { diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index f17a849..90b0059 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -23,7 +23,7 @@ func getMeshConfiguration() *conf.DaemonConfiguration { SkipCertVerification: true, Timeout: 5, Profile: false, - StubWg: false, + StubWg: true, SyncRate: 2, KeepAliveTime: 60, ClusterSize: 64, diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index b04f453..bbe9afc 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -20,6 +20,12 @@ type Route interface { GetPath() []string } +func RouteEqual(r1 Route, r2 Route) bool { + return r1.GetDestination().IP.Equal(r2.GetDestination().IP) && + r1.GetHopCount() == r2.GetHopCount() && + slices.Equal(r1.GetPath(), r2.GetPath()) +} + func RouteEquals(r1, r2 Route) bool { return r1.GetDestination().String() == r2.GetDestination().String() && r1.GetHopCount() == r2.GetHopCount() && diff --git a/pkg/robin/requester_test.go b/pkg/robin/requester_test.go index 548691c..89a222c 100644 --- a/pkg/robin/requester_test.go +++ b/pkg/robin/requester_test.go @@ -3,6 +3,7 @@ package robin import ( "testing" + "github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/mesh" @@ -17,9 +18,11 @@ func TestCreateMeshRepliesMeshId(t *testing.T) { requester := getRequester() err := requester.CreateMesh(&ipc.NewMeshArgs{ - IfName: "wg0", - WgPort: 5000, - Endpoint: "abc.com", + WgArgs: ipc.WireGuardArgs{ + WgPort: 500, + Endpoint: "abc.com:1234", + Role: "peer", + }, }, &reply) if err != nil { @@ -52,9 +55,8 @@ func TestListMeshesMeshesNotEmpty(t *testing.T) { requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ MeshId: "tim123", - DevName: "wg0", - WgPort: 5000, MeshBytes: make([]byte, 0), + Conf: &conf.WgConfiguration{}, }) err := requester.ListMeshes("", &reply)