diff --git a/cmd/smegd/main.go b/cmd/smegd/main.go index f8725df..c5db5a7 100644 --- a/cmd/smegd/main.go +++ b/cmd/smegd/main.go @@ -16,17 +16,20 @@ import ( ) func main() { + if len(os.Args) != 2 { logging.Log.WriteErrorf("Did not provide configuration") return } - conf, err := conf.ParseDaemonConfiguration(os.Args[1]) + configuration, err := conf.ParseDaemonConfiguration(os.Args[1]) if err != nil { logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error()) return } + logging.SetLogger(logging.NewLogrusLogger(configuration.LogLevel)) + client, err := wgctrl.New() if err != nil { @@ -34,7 +37,7 @@ func main() { return } - if conf.Profile { + if configuration.Profile { go func() { http.ListenAndServe("localhost:6060", nil) }() @@ -45,7 +48,7 @@ func main() { var syncProvider sync.SyncServiceImpl ctrlServerParams := ctrlserver.NewCtrlServerParams{ - Conf: conf, + Conf: configuration, CtrlProvider: &robinRpc, SyncProvider: &syncProvider, Client: client, diff --git a/pkg/automerge/automerge_test.go b/pkg/automerge/automerge_test.go index 6d88b3d..9e6f322 100644 --- a/pkg/automerge/automerge_test.go +++ b/pkg/automerge/automerge_test.go @@ -83,7 +83,6 @@ func TestAddNodeAddRoute(t *testing.T) { testParams.manager.AddNode(testNode) testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{ Destination: destination, - HopCount: 0, Path: make([]string, 0), }) updatedNode, err := testParams.manager.GetNode(pubKey.String()) @@ -297,7 +296,6 @@ func TestAddRoutesNodeDoesNotExist(t *testing.T) { err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{ Destination: destination, - HopCount: 0, Path: make([]string, 0), }) diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index 19d65f6..88c284f 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -63,7 +63,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string } else { ipFunc := lib.GetPublicIP - if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { + if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY { ipFunc = lib.GetOutboundIP } diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index e0bf283..0d3f5e9 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -8,14 +8,7 @@ import ( "gopkg.in/yaml.v3" ) -type WgMeshConfigurationError struct { - msg string -} - -func (m *WgMeshConfigurationError) Error() string { - return m.msg -} - +// NodeType types of the node either peer or client type NodeType string const ( @@ -23,11 +16,23 @@ const ( CLIENT_ROLE NodeType = "client" ) +// IPDiscovery: what IPDiscovery service to use type IPDiscovery string const ( + // Public IP use an IP service to discover your IP PUBLIC_IP_DISCOVERY IPDiscovery = "public" - DNS_IP_DISCOVERY IPDiscovery = "dns" + // Outgonig: Use your labelled packet IP + OUTGOING_IP_DISCOVERY IPDiscovery = "outgoing" +) + +// Loglevel: what log level to use either error info or warning +type LogLevel string + +const ( + ERROR LogLevel = "error" + WARNING LogLevel = "warning" + INFO LogLevel = "info" ) // WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can @@ -35,7 +40,7 @@ const ( type WgConfiguration struct { // IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public // service for IPDiscoverability - IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"` + IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=outgoing"` // AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"` // AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route @@ -76,24 +81,26 @@ type DaemonConfiguration struct { Profile bool `yaml:"profile"` // StubWg whether or not to stub the WireGuard types StubWg bool `yaml:"stubWg"` - // SyncTime specifies how long the minimum time should be between synchronisation - SyncTime int `yaml:"syncTime" validate:"required,gte=1"` - // PullTime specifies the interval between checking for configuration changes - PullTime int `yaml:"pullTime" validate:"gte=0"` - // HeartBeat: number of seconds before the leader of the mesh sends an update to + // SyncInterval specifies how long the minimum time should be between synchronisation + SyncInterval int `yaml:"syncInterval" validate:"required,gte=1"` + // PullInterval specifies the interval between checking for configuration changes + PullInterval int `yaml:"pullInterval" validate:"gte=0"` + // Heartbeat: number of seconds before the leader of the mesh sends an update to // send to every member in the mesh - HeartBeat int `yaml:"heartBeatTime" validate:"required,gte=1"` + Heartbeat int `yaml:"heartbeatInterval" validate:"required,gte=1"` // ClusterSize specifies how many neighbours you should synchronise with per round ClusterSize int `yaml:"clusterSize" validate:"gte=1"` // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"` - // BranchRate specifies the number of nodes to synchronise with when a node has + // Branch specifies the number of nodes to synchronise with when a node has // new changes to send to the mesh - BranchRate int `yaml:"branchRate" validate:"required,gte=1"` + Branch int `yaml:"branch" validate:"required,gte=1"` // InfectionCount: number of time to sync before an update can no longer be 'caught' InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"` // BaseConfiguration base WireGuard configuration to use, this is used when none is provided BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"` + // LogLevel specifies the log level to output, defaults is warning + LogLevel LogLevel `yaml:"logLevel" validate:"eq=info|eq=warning|eq=error"` } // ValdiateMeshConfiguration: validates the mesh configuration @@ -121,9 +128,18 @@ func ValidateMeshConfiguration(conf *WgConfiguration) error { } // ValidateDaemonConfiguration: validates the dameon configuration that is used. -func ValidateDaemonConfiguration(c *DaemonConfiguration) error { +func ValidateDaemonConfiguration(conf *DaemonConfiguration) error { + if conf.BaseConfiguration.KeepAliveWg == nil { + var keepAlive int = 0 + conf.BaseConfiguration.KeepAliveWg = &keepAlive + } + + if conf.LogLevel == "" { + conf.LogLevel = WARNING + } + validate := validator.New(validator.WithRequiredStructEnabled()) - err := validate.Struct(c) + err := validate.Struct(conf) return err } @@ -143,11 +159,6 @@ func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) { return nil, err } - if conf.BaseConfiguration.KeepAliveWg == nil { - var keepAlive int = 0 - conf.BaseConfiguration.KeepAliveWg = &keepAlive - } - return &conf, ValidateDaemonConfiguration(&conf) } diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go index 4189470..e7066f8 100644 --- a/pkg/conf/conf_test.go +++ b/pkg/conf/conf_test.go @@ -21,12 +21,12 @@ func getExampleConfiguration() *DaemonConfiguration { Timeout: 5, Profile: false, StubWg: false, - SyncTime: 2, - HeartBeat: 2, + SyncInterval: 2, + Heartbeat: 2, ClusterSize: 64, InterClusterChance: 0.15, - BranchRate: 3, - PullTime: 0, + Branch: 3, + PullInterval: 0, InfectionCount: 2, BaseConfiguration: WgConfiguration{ IPDiscovery: &discovery, @@ -154,7 +154,7 @@ func TestRoleTypeNotSpecified(t *testing.T) { func TestBranchRateZero(t *testing.T) { conf := getExampleConfiguration() - conf.BranchRate = 0 + conf.Branch = 0 err := ValidateDaemonConfiguration(conf) @@ -165,7 +165,7 @@ func TestBranchRateZero(t *testing.T) { func TestsyncTimeZero(t *testing.T) { conf := getExampleConfiguration() - conf.SyncTime = 0 + conf.SyncInterval = 0 err := ValidateDaemonConfiguration(conf) @@ -176,7 +176,7 @@ func TestsyncTimeZero(t *testing.T) { func TestKeepAliveTimeZero(t *testing.T) { conf := getExampleConfiguration() - conf.HeartBeat = 0 + conf.Heartbeat = 0 err := ValidateDaemonConfiguration(conf) if err == nil { @@ -218,7 +218,7 @@ func TestInfectionCountOne(t *testing.T) { func TestPullTimeNegative(t *testing.T) { conf := getExampleConfiguration() - conf.PullTime = -1 + conf.PullInterval = -1 err := ValidateDaemonConfiguration(conf) diff --git a/pkg/crdt/datastore.go b/pkg/crdt/datastore.go index ed4fea0..a6c871c 100644 --- a/pkg/crdt/datastore.go +++ b/pkg/crdt/datastore.go @@ -268,7 +268,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { peerToUpdate := peers[0] - if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.HeartBeat) { + if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.Heartbeat) { m.store.Mark(peerToUpdate) if len(peers) < 2 { @@ -341,6 +341,7 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro for _, route := range routes { changes = true + logging.Log.WriteInfof("deleting: %s", route.GetDestination().String()) delete(node.Routes, route.GetDestination().String()) } diff --git a/pkg/crdt/datastore_test.go b/pkg/crdt/datastore_test.go index 6d9459b..5927349 100644 --- a/pkg/crdt/datastore_test.go +++ b/pkg/crdt/datastore_test.go @@ -21,7 +21,7 @@ func setUpTests() *TestParams { advertiseRoutes := false advertiseDefaultRoute := false role := conf.PEER_ROLE - discovery := conf.DNS_IP_DISCOVERY + discovery := conf.OUTGOING_IP_DISCOVERY factory := &TwoPhaseMapFactory{ Config: &conf.DaemonConfiguration{ @@ -32,11 +32,11 @@ func setUpTests() *TestParams { GrpcPort: 0, Timeout: 20, Profile: false, - SyncTime: 2, - HeartBeat: 10, + SyncInterval: 2, + Heartbeat: 10, ClusterSize: 32, InterClusterChance: 0.15, - BranchRate: 3, + Branch: 3, InfectionCount: 3, BaseConfiguration: conf.WgConfiguration{ IPDiscovery: &discovery, @@ -215,7 +215,6 @@ func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) { testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{ Destination: destination, - HopCount: 0, Path: make([]string, 0), }) @@ -238,7 +237,6 @@ func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) { _, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64") route := &mesh.RouteStub{ Destination: destination, - HopCount: 0, Path: make([]string, 0), } diff --git a/pkg/crdt/factory.go b/pkg/crdt/factory.go index 518bbc4..9e9f315 100644 --- a/pkg/crdt/factory.go +++ b/pkg/crdt/factory.go @@ -27,7 +27,7 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) h := fnv.New64a() h.Write([]byte(s)) return h.Sum64() - }, uint64(3*f.Config.HeartBeat)), + }, uint64(3*f.Config.Heartbeat)), }, nil } @@ -71,7 +71,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string } else { ipFunc := lib.GetPublicIP - if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { + if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY { ipFunc = lib.GetOutboundIP } diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index cdd25bb..440949f 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -113,7 +113,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { heartbeatTimer := lib.NewTimer(func() error { logging.Log.WriteInfof("checking heartbeat") return ctrlServer.MeshManager.UpdateTimeStamp() - }, params.Conf.HeartBeat) + }, params.Conf.Heartbeat) ctrlServer.timers = append(ctrlServer.timers, syncTimer, heartbeatTimer) diff --git a/pkg/grpc/ctrlserver.pb.go b/pkg/grpc/ctrlserver.pb.go new file mode 100644 index 0000000..2c87b66 --- /dev/null +++ b/pkg/grpc/ctrlserver.pb.go @@ -0,0 +1,212 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.12 +// source: pkg/grpc/ctrlserver.proto + +package rpc + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type GetMeshRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"` +} + +func (x *GetMeshRequest) Reset() { + *x = GetMeshRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetMeshRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetMeshRequest) ProtoMessage() {} + +func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead. +func (*GetMeshRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{0} +} + +func (x *GetMeshRequest) GetMeshId() string { + if x != nil { + return x.MeshId + } + return "" +} + +type GetMeshReply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Mesh []byte `protobuf:"bytes,1,opt,name=mesh,proto3" json:"mesh,omitempty"` +} + +func (x *GetMeshReply) Reset() { + *x = GetMeshReply{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetMeshReply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetMeshReply) ProtoMessage() {} + +func (x *GetMeshReply) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead. +func (*GetMeshReply) Descriptor() ([]byte, []int) { + return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{1} +} + +func (x *GetMeshReply) GetMesh() []byte { + if x != nil { + return x.Mesh + } + return nil +} + +var File_pkg_grpc_ctrlserver_proto protoreflect.FileDescriptor + +var file_pkg_grpc_ctrlserver_proto_rawDesc = []byte{ + 0x0a, 0x19, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, + 0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, + 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, + 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, + 0x65, 0x73, 0x68, 0x32, 0x4f, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, + 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, + 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, + 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, + 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_grpc_ctrlserver_proto_rawDescOnce sync.Once + file_pkg_grpc_ctrlserver_proto_rawDescData = file_pkg_grpc_ctrlserver_proto_rawDesc +) + +func file_pkg_grpc_ctrlserver_proto_rawDescGZIP() []byte { + file_pkg_grpc_ctrlserver_proto_rawDescOnce.Do(func() { + file_pkg_grpc_ctrlserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_proto_rawDescData) + }) + return file_pkg_grpc_ctrlserver_proto_rawDescData +} + +var file_pkg_grpc_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_pkg_grpc_ctrlserver_proto_goTypes = []interface{}{ + (*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest + (*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply +} +var file_pkg_grpc_ctrlserver_proto_depIdxs = []int32{ + 0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest + 1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pkg_grpc_ctrlserver_proto_init() } +func file_pkg_grpc_ctrlserver_proto_init() { + if File_pkg_grpc_ctrlserver_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_grpc_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetMeshRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetMeshReply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_grpc_ctrlserver_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pkg_grpc_ctrlserver_proto_goTypes, + DependencyIndexes: file_pkg_grpc_ctrlserver_proto_depIdxs, + MessageInfos: file_pkg_grpc_ctrlserver_proto_msgTypes, + }.Build() + File_pkg_grpc_ctrlserver_proto = out.File + file_pkg_grpc_ctrlserver_proto_rawDesc = nil + file_pkg_grpc_ctrlserver_proto_goTypes = nil + file_pkg_grpc_ctrlserver_proto_depIdxs = nil +} diff --git a/pkg/grpc/ctrlserver/ctrlserver.proto b/pkg/grpc/ctrlserver.proto similarity index 100% rename from pkg/grpc/ctrlserver/ctrlserver.proto rename to pkg/grpc/ctrlserver.proto diff --git a/pkg/grpc/ctrlserver_grpc.pb.go b/pkg/grpc/ctrlserver_grpc.pb.go new file mode 100644 index 0000000..ff84ba7 --- /dev/null +++ b/pkg/grpc/ctrlserver_grpc.pb.go @@ -0,0 +1,105 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.21.12 +// source: pkg/grpc/ctrlserver.proto + +package rpc + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// MeshCtrlServerClient is the client API for MeshCtrlServer service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type MeshCtrlServerClient interface { + GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) +} + +type meshCtrlServerClient struct { + cc grpc.ClientConnInterface +} + +func NewMeshCtrlServerClient(cc grpc.ClientConnInterface) MeshCtrlServerClient { + return &meshCtrlServerClient{cc} +} + +func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) { + out := new(GetMeshReply) + err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/GetMesh", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// MeshCtrlServerServer is the server API for MeshCtrlServer service. +// All implementations must embed UnimplementedMeshCtrlServerServer +// for forward compatibility +type MeshCtrlServerServer interface { + GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) + mustEmbedUnimplementedMeshCtrlServerServer() +} + +// UnimplementedMeshCtrlServerServer must be embedded to have forward compatible implementations. +type UnimplementedMeshCtrlServerServer struct { +} + +func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented") +} +func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {} + +// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to MeshCtrlServerServer will +// result in compilation errors. +type UnsafeMeshCtrlServerServer interface { + mustEmbedUnimplementedMeshCtrlServerServer() +} + +func RegisterMeshCtrlServerServer(s grpc.ServiceRegistrar, srv MeshCtrlServerServer) { + s.RegisterService(&MeshCtrlServer_ServiceDesc, srv) +} + +func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetMeshRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MeshCtrlServerServer).GetMesh(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/rpctypes.MeshCtrlServer/GetMesh", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MeshCtrlServerServer).GetMesh(ctx, req.(*GetMeshRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "rpctypes.MeshCtrlServer", + HandlerType: (*MeshCtrlServerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetMesh", + Handler: _MeshCtrlServer_GetMesh_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "pkg/grpc/ctrlserver.proto", +} diff --git a/pkg/grpc/syncservice.pb.go b/pkg/grpc/syncservice.pb.go new file mode 100644 index 0000000..77f8317 --- /dev/null +++ b/pkg/grpc/syncservice.pb.go @@ -0,0 +1,233 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.12 +// source: pkg/grpc/syncservice.proto + +package rpc + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type SyncMeshRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"` + Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"` +} + +func (x *SyncMeshRequest) Reset() { + *x = SyncMeshRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_syncservice_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMeshRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMeshRequest) ProtoMessage() {} + +func (x *SyncMeshRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_syncservice_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMeshRequest.ProtoReflect.Descriptor instead. +func (*SyncMeshRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{0} +} + +func (x *SyncMeshRequest) GetMeshId() string { + if x != nil { + return x.MeshId + } + return "" +} + +func (x *SyncMeshRequest) GetChanges() []byte { + if x != nil { + return x.Changes + } + return nil +} + +type SyncMeshReply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"` +} + +func (x *SyncMeshReply) Reset() { + *x = SyncMeshReply{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_syncservice_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMeshReply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMeshReply) ProtoMessage() {} + +func (x *SyncMeshReply) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_syncservice_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMeshReply.ProtoReflect.Descriptor instead. +func (*SyncMeshReply) Descriptor() ([]byte, []int) { + return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{1} +} + +func (x *SyncMeshReply) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *SyncMeshReply) GetChanges() []byte { + if x != nil { + return x.Changes + } + return nil +} + +var File_pkg_grpc_syncservice_proto protoreflect.FileDescriptor + +var file_pkg_grpc_syncservice_proto_rawDesc = []byte{ + 0x0a, 0x1a, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x73, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x73, 0x79, + 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x43, 0x0a, 0x0f, 0x53, 0x79, 0x6e, + 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, + 0x73, 0x68, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x22, 0x43, + 0x0a, 0x0d, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, + 0x6e, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, + 0x67, 0x65, 0x73, 0x32, 0x59, 0x0a, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x4a, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1c, + 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e, + 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x73, + 0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x09, + 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, +} + +var ( + file_pkg_grpc_syncservice_proto_rawDescOnce sync.Once + file_pkg_grpc_syncservice_proto_rawDescData = file_pkg_grpc_syncservice_proto_rawDesc +) + +func file_pkg_grpc_syncservice_proto_rawDescGZIP() []byte { + file_pkg_grpc_syncservice_proto_rawDescOnce.Do(func() { + file_pkg_grpc_syncservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_syncservice_proto_rawDescData) + }) + return file_pkg_grpc_syncservice_proto_rawDescData +} + +var file_pkg_grpc_syncservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_pkg_grpc_syncservice_proto_goTypes = []interface{}{ + (*SyncMeshRequest)(nil), // 0: syncservice.SyncMeshRequest + (*SyncMeshReply)(nil), // 1: syncservice.SyncMeshReply +} +var file_pkg_grpc_syncservice_proto_depIdxs = []int32{ + 0, // 0: syncservice.SyncService.SyncMesh:input_type -> syncservice.SyncMeshRequest + 1, // 1: syncservice.SyncService.SyncMesh:output_type -> syncservice.SyncMeshReply + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pkg_grpc_syncservice_proto_init() } +func file_pkg_grpc_syncservice_proto_init() { + if File_pkg_grpc_syncservice_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_grpc_syncservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMeshRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_syncservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMeshReply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_grpc_syncservice_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pkg_grpc_syncservice_proto_goTypes, + DependencyIndexes: file_pkg_grpc_syncservice_proto_depIdxs, + MessageInfos: file_pkg_grpc_syncservice_proto_msgTypes, + }.Build() + File_pkg_grpc_syncservice_proto = out.File + file_pkg_grpc_syncservice_proto_rawDesc = nil + file_pkg_grpc_syncservice_proto_goTypes = nil + file_pkg_grpc_syncservice_proto_depIdxs = nil +} diff --git a/pkg/grpc/ctrlserver/syncservice.proto b/pkg/grpc/syncservice.proto similarity index 67% rename from pkg/grpc/ctrlserver/syncservice.proto rename to pkg/grpc/syncservice.proto index b08faf1..3081949 100644 --- a/pkg/grpc/ctrlserver/syncservice.proto +++ b/pkg/grpc/syncservice.proto @@ -4,18 +4,9 @@ package syncservice; option go_package = "pkg/rpc"; service SyncService { - rpc GetConf(GetConfRequest) returns (GetConfReply) {} rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {} } -message GetConfRequest { - string meshId = 1; -} - -message GetConfReply { - bytes mesh = 1; -} - message SyncMeshRequest { string meshId = 1; bytes changes = 2; diff --git a/pkg/grpc/syncservice_grpc.pb.go b/pkg/grpc/syncservice_grpc.pb.go new file mode 100644 index 0000000..5fc27a3 --- /dev/null +++ b/pkg/grpc/syncservice_grpc.pb.go @@ -0,0 +1,137 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.21.12 +// source: pkg/grpc/syncservice.proto + +package rpc + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// SyncServiceClient is the client API for SyncService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type SyncServiceClient interface { + SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error) +} + +type syncServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewSyncServiceClient(cc grpc.ClientConnInterface) SyncServiceClient { + return &syncServiceClient{cc} +} + +func (c *syncServiceClient) SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error) { + stream, err := c.cc.NewStream(ctx, &SyncService_ServiceDesc.Streams[0], "/syncservice.SyncService/SyncMesh", opts...) + if err != nil { + return nil, err + } + x := &syncServiceSyncMeshClient{stream} + return x, nil +} + +type SyncService_SyncMeshClient interface { + Send(*SyncMeshRequest) error + Recv() (*SyncMeshReply, error) + grpc.ClientStream +} + +type syncServiceSyncMeshClient struct { + grpc.ClientStream +} + +func (x *syncServiceSyncMeshClient) Send(m *SyncMeshRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *syncServiceSyncMeshClient) Recv() (*SyncMeshReply, error) { + m := new(SyncMeshReply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// SyncServiceServer is the server API for SyncService service. +// All implementations must embed UnimplementedSyncServiceServer +// for forward compatibility +type SyncServiceServer interface { + SyncMesh(SyncService_SyncMeshServer) error + mustEmbedUnimplementedSyncServiceServer() +} + +// UnimplementedSyncServiceServer must be embedded to have forward compatible implementations. +type UnimplementedSyncServiceServer struct { +} + +func (UnimplementedSyncServiceServer) SyncMesh(SyncService_SyncMeshServer) error { + return status.Errorf(codes.Unimplemented, "method SyncMesh not implemented") +} +func (UnimplementedSyncServiceServer) mustEmbedUnimplementedSyncServiceServer() {} + +// UnsafeSyncServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to SyncServiceServer will +// result in compilation errors. +type UnsafeSyncServiceServer interface { + mustEmbedUnimplementedSyncServiceServer() +} + +func RegisterSyncServiceServer(s grpc.ServiceRegistrar, srv SyncServiceServer) { + s.RegisterService(&SyncService_ServiceDesc, srv) +} + +func _SyncService_SyncMesh_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(SyncServiceServer).SyncMesh(&syncServiceSyncMeshServer{stream}) +} + +type SyncService_SyncMeshServer interface { + Send(*SyncMeshReply) error + Recv() (*SyncMeshRequest, error) + grpc.ServerStream +} + +type syncServiceSyncMeshServer struct { + grpc.ServerStream +} + +func (x *syncServiceSyncMeshServer) Send(m *SyncMeshReply) error { + return x.ServerStream.SendMsg(m) +} + +func (x *syncServiceSyncMeshServer) Recv() (*SyncMeshRequest, error) { + m := new(SyncMeshRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// SyncService_ServiceDesc is the grpc.ServiceDesc for SyncService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var SyncService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "syncservice.SyncService", + HandlerType: (*SyncServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "SyncMesh", + Handler: _SyncService_SyncMesh_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "pkg/grpc/syncservice.proto", +} diff --git a/pkg/log/log.go b/pkg/log/log.go index 71ff0f1..d0d8002 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -6,6 +6,7 @@ import ( "os" "github.com/sirupsen/logrus" + "github.com/tim-beatham/smegmesh/pkg/conf" ) var ( @@ -39,17 +40,29 @@ func (l *LogrusLogger) Writer() io.Writer { return l.logger.Writer() } -func NewLogrusLogger() *LogrusLogger { +func NewLogrusLogger(confLevel conf.LogLevel) *LogrusLogger { + + var level logrus.Level + + switch confLevel { + case conf.ERROR: + level = logrus.ErrorLevel + case conf.WARNING: + level = logrus.WarnLevel + case conf.INFO: + level = logrus.InfoLevel + } + logger := logrus.New() logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) logger.SetOutput(os.Stdout) - logger.SetLevel(logrus.InfoLevel) + logger.SetLevel(level) return &LogrusLogger{logger: logger} } func init() { - SetLogger(NewLogrusLogger()) + SetLogger(NewLogrusLogger(conf.INFO)) } func SetLogger(l Logger) { diff --git a/pkg/mesh/config.go b/pkg/mesh/config.go index 2d42a7d..9e60bdf 100644 --- a/pkg/mesh/config.go +++ b/pkg/mesh/config.go @@ -175,7 +175,6 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][ rn.gateway = peerPub.String() rn.route = &RouteStub{ Destination: rn.route.GetDestination(), - HopCount: rn.route.GetHopCount() + 1, Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()), } } @@ -283,10 +282,14 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes installedRoutes := make([]lib.Route, 0) for _, route := range peerCfgs[0].AllowedIPs { - installedRoutes = append(installedRoutes, lib.Route{ - Gateway: peer.GetWgHost().IP, - Destination: route, - }) + // Don't install routes that we are directly apart + // Dont install default route wgctrl handles this for us + if !meshNet.Contains(route.IP) { + installedRoutes = append(installedRoutes, lib.Route{ + Gateway: peer.GetWgHost().IP, + Destination: route, + }) + } } cfg := wgtypes.Config{ @@ -306,9 +309,8 @@ func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mes ula := &ip.ULABuilder{} ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) - _, defaultRoute, _ := net.ParseCIDR("::/0") - - if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) { + // Check there is no overlap in network and its not the default route + if !ipNet.Contains(route.IP) { routes = append(routes, lib.Route{ Gateway: node.GetWgHost().IP, Destination: route, diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index bcdc841..b377387 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -347,7 +347,7 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { return nil } -// LeaveMesh: leaves the mesh network +// LeaveMesh: leaves the mesh network and force a synchronsiation func (s *MeshManagerImpl) LeaveMesh(meshId string) error { mesh := s.GetMesh(meshId) @@ -468,6 +468,9 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client { func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { meshes := make(map[string]MeshProvider) + // GetMesh: copies the map of meshes to a new map + // to prevent a whole range of concurrency issues + // due to iteration and modification s.meshLock.RLock() for id, mesh := range s.meshes { diff --git a/pkg/mesh/manager_test.go b/pkg/mesh/manager_test.go index 4a80f17..1d35e03 100644 --- a/pkg/mesh/manager_test.go +++ b/pkg/mesh/manager_test.go @@ -24,11 +24,11 @@ func getMeshConfiguration() *conf.DaemonConfiguration { Timeout: 5, Profile: false, StubWg: true, - SyncTime: 2, - HeartBeat: 60, + SyncInterval: 2, + Heartbeat: 60, ClusterSize: 64, InterClusterChance: 0.15, - BranchRate: 3, + Branch: 3, InfectionCount: 3, BaseConfiguration: conf.WgConfiguration{ IPDiscovery: &ipDiscovery, diff --git a/pkg/mesh/route.go b/pkg/mesh/route.go index f6eab4b..b226e8b 100644 --- a/pkg/mesh/route.go +++ b/pkg/mesh/route.go @@ -46,7 +46,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error { defaultRoute := &RouteStub{ Destination: ipv6Default, - HopCount: 0, Path: []string{mesh1.GetMeshId()}, } @@ -75,7 +74,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error { routeValues = append(routeValues, &RouteStub{ Destination: mesh1IpNet, - HopCount: 0, Path: []string{mesh1.GetMeshId()}, }) @@ -106,15 +104,12 @@ func (r *RouteManagerImpl) UpdateRoutes() error { } toRemove := make([]Route, 0) - prevRoutes, err := mesh.GetRoutes(NodeID(self)) - if err != nil { - return err - } + prevRoutes := self.GetRoutes() for _, route := range prevRoutes { if !lib.Contains(meshRoutes, func(r Route) bool { - return RouteEquals(r, route) + return RouteEqual(r, route) }) { toRemove = append(toRemove, route) } diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index 66cb5b7..c5df71d 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -21,12 +21,6 @@ type Route interface { } func RouteEqual(r1 Route, r2 Route) bool { - return r1.GetDestination().IP.Equal(r2.GetDestination().IP) && - r1.GetHopCount() == r2.GetHopCount() && - slices.Equal(r1.GetPath(), r2.GetPath()) -} - -func RouteEquals(r1, r2 Route) bool { return r1.GetDestination().String() == r2.GetDestination().String() && r1.GetHopCount() == r2.GetHopCount() && slices.Equal(r1.GetPath(), r2.GetPath()) @@ -34,7 +28,6 @@ func RouteEquals(r1, r2 Route) bool { type RouteStub struct { Destination *net.IPNet - HopCount int Path []string } @@ -43,7 +36,7 @@ func (r *RouteStub) GetDestination() *net.IPNet { } func (r *RouteStub) GetHopCount() int { - return r.HopCount + return len(r.Path) } func (r *RouteStub) GetPath() []string { diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index 8a31c7e..c3cb592 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strconv" "time" "github.com/tim-beatham/smegmesh/pkg/conf" @@ -141,7 +140,7 @@ func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error { return err } - *reply = strconv.FormatBool(true) + *reply = fmt.Sprintf("Successfully Joined: %s", args.MeshId) return nil } @@ -235,7 +234,7 @@ func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error return err } - *reply = "success" + *reply = fmt.Sprintf("Set service %s in %s to %s", service.Service, service.MeshId, service.Value) return nil } @@ -247,7 +246,7 @@ func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) return err } - *reply = "success" + *reply = fmt.Sprintf("Removed service %s from %s", service.Service, service.MeshId) return nil } diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 6d28da6..0b73a0b 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -55,7 +55,7 @@ func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) { logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId()) // If not synchronised in certain time pull from random neighbour - if s.configuration.PullTime != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.configuration.PullTime) { + if s.configuration.PullInterval != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.configuration.PullInterval) { return s.Pull(self, correspondingMesh) } @@ -88,7 +88,7 @@ func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) { gossipNodes = neighbours[:redundancyLength] } else { neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) - gossipNodes = lib.RandomSubsetOfLength(neighbours, s.configuration.BranchRate) + gossipNodes = lib.RandomSubsetOfLength(neighbours, s.configuration.Branch) if len(nodeNames) > s.configuration.ClusterSize && rand.Float64() < s.configuration.InterClusterChance { gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String()) @@ -97,26 +97,37 @@ func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) { var succeeded bool = false - // Do this synchronously to conserve bandwidth - for _, node := range gossipNodes { - correspondingPeer, err := correspondingMesh.GetNode(node) + var wait sync.WaitGroup - if correspondingPeer == nil || err != nil { - logging.Log.WriteErrorf("node %s does not exist", node) - continue + for index, node := range gossipNodes { + wait.Add(1) + + syncNode := func(i int) { + correspondingPeer, err := correspondingMesh.GetNode(node) + + defer wait.Done() + + if correspondingPeer == nil || err != nil { + logging.Log.WriteErrorf("node %s does not exist", node) + return + } + + err = s.requester.SyncMesh(correspondingMesh, correspondingPeer) + + if err == nil || err == io.EOF { + succeeded = true + } + + if err != nil { + logging.Log.WriteErrorf(err.Error()) + } } - err = s.requester.SyncMesh(correspondingMesh.GetMeshId(), correspondingPeer) - - if err == nil || err == io.EOF { - succeeded = true - } - - if err != nil { - logging.Log.WriteErrorf(err.Error()) - } + go syncNode(index) } + wait.Wait() + s.syncCount++ logging.Log.WriteInfof("sync time: %v", time.Since(before)) logging.Log.WriteInfof("number of syncs: %d", s.syncCount) @@ -158,7 +169,7 @@ func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) (bool, err return false, fmt.Errorf("node %s does not exist in the mesh", neighbour[0]) } - err = s.requester.SyncMesh(mesh.GetMeshId(), pullNode) + err = s.requester.SyncMesh(mesh, pullNode) if err == nil || err == io.EOF { s.lastSync[mesh.GetMeshId()] = time.Now().Unix() @@ -180,7 +191,7 @@ func (s *SyncerImpl) SyncMeshes() error { s.lastPollLock.Lock() meshesToSync := lib.Filter(lib.MapValues(meshes), func(mesh mesh.MeshProvider) bool { - return time.Now().Unix()-s.lastPoll[mesh.GetMeshId()] >= int64(s.configuration.SyncTime) + return time.Now().Unix()-s.lastPoll[mesh.GetMeshId()] >= int64(s.configuration.SyncInterval) }) s.lastPollLock.Unlock() diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go index c3cad2a..ea16e16 100644 --- a/pkg/sync/syncererror.go +++ b/pkg/sync/syncererror.go @@ -10,7 +10,7 @@ import ( // SyncErrorHandler: Handles errors when attempting to sync type SyncErrorHandler interface { - Handle(meshId string, endpoint string, err error) bool + Handle(mesh mesh.MeshProvider, endpoint string, err error) bool } // SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler @@ -19,8 +19,7 @@ type SyncErrorHandlerImpl struct { connManager conn.ConnectionManager } -func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { - mesh := s.meshManager.GetMesh(meshId) +func (s *SyncErrorHandlerImpl) handleFailed(mesh mesh.MeshProvider, nodeId string) bool { mesh.Mark(nodeId) node, err := mesh.GetNode(nodeId) @@ -30,13 +29,7 @@ func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { return true } -func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(meshId string, nodeId string) bool { - mesh := s.meshManager.GetMesh(meshId) - - if mesh == nil { - return true - } - +func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(mesh mesh.MeshProvider, nodeId string) bool { node, err := mesh.GetNode(nodeId) if err != nil { @@ -47,16 +40,16 @@ func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(meshId string, nodeId stri return true } -func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool { +func (s *SyncErrorHandlerImpl) Handle(mesh mesh.MeshProvider, nodeId string, err error) bool { errStatus, _ := status.FromError(err) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) switch errStatus.Code() { case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound: - return s.handleFailed(meshId, nodeId) + return s.handleFailed(mesh, nodeId) case codes.DeadlineExceeded: - return s.handleDeadlineExceeded(meshId, nodeId) + return s.handleDeadlineExceeded(mesh, nodeId) } return false diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index a017692..4f9cfe9 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -2,7 +2,6 @@ package sync import ( "context" - "errors" "io" "time" @@ -15,8 +14,7 @@ import ( // SyncRequester: coordinates the syncing of meshes type SyncRequester interface { - GetMesh(meshId string, ifName string, port int, endPoint string) error - SyncMesh(meshid string, meshNode mesh.MeshNode) error + SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error } type SyncRequesterImpl struct { @@ -26,42 +24,9 @@ type SyncRequesterImpl struct { errorHdlr SyncErrorHandler } -// GetMesh: Retrieves the local state of the mesh at the endpoint -func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error { - peerConnection, err := s.connectionManager.GetConnection(endPoint) - - if err != nil { - return err - } - - client, err := peerConnection.GetClient() - - if err != nil { - return err - } - - c := rpc.NewSyncServiceClient(client) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId}) - - if err != nil { - return err - } - - err = s.manager.AddMesh(&mesh.AddMeshParams{ - MeshId: meshId, - WgPort: port, - MeshBytes: reply.Mesh, - }) - return err -} - // handleErr: handleGrpc errors -func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error { - ok := s.errorHdlr.Handle(meshId, pubKey, err) +func (s *SyncRequesterImpl) handleErr(mesh mesh.MeshProvider, pubKey string, err error) error { + ok := s.errorHdlr.Handle(mesh, pubKey, err) if ok { return nil @@ -70,7 +35,7 @@ func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error { } // SyncMesh: Proactively send a sync request to the other mesh -func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error { +func (s *SyncRequesterImpl) SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error { endpoint := meshNode.GetHostEndpoint() pubKey, _ := meshNode.GetPublicKey() @@ -86,15 +51,9 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro return err } - mesh := s.manager.GetMesh(meshId) - - if mesh == nil { - return errors.New("mesh does not exist") - } - c := rpc.NewSyncServiceClient(client) - syncTimeOut := float64(s.configuration.SyncTime) * float64(time.Second) + syncTimeOut := float64(s.configuration.SyncInterval) * float64(time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) defer cancel() @@ -102,10 +61,10 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro err = s.syncMesh(mesh, ctx, c) if err != nil { - s.handleErr(meshId, pubKey.String(), err) + s.handleErr(mesh, pubKey.String(), err) } - logging.Log.WriteInfof("synced with node: %s meshId: %s\n", endpoint, meshId) + logging.Log.WriteInfof("synced with node: %s meshId: %s\n", endpoint, mesh.GetMeshId()) return err }