diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index e00fccc..8188d37 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -39,7 +39,7 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { crdt.Services = make(map[string]string) crdt.Timestamp = time.Now().Unix() - c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) + c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) } func (c *CrdtMeshManager) isPeer(nodeId string) bool { diff --git a/pkg/automerge/factory.go b/pkg/automerge/factory.go index e23520c..dd1632c 100644 --- a/pkg/automerge/factory.go +++ b/pkg/automerge/factory.go @@ -28,8 +28,14 @@ type MeshNodeFactory struct { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { hostName := f.getAddress(params) + grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) + + if f.Config.Role == conf.CLIENT_ROLE { + grpcEndpoint = "-" + } + return &MeshNodeCrdt{ - HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort), + HostEndpoint: grpcEndpoint, PublicKey: params.PublicKey.String(), WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), @@ -38,7 +44,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod Routes: map[string]interface{}{}, Description: "", Alias: "", - Type: string(params.Role), + Type: string(f.Config.Role), } } @@ -51,7 +57,13 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string } else if len(f.Config.Endpoint) != 0 { hostName = f.Config.Endpoint } else { - ip, err := lib.GetPublicIP() + ipFunc := lib.GetPublicIP + + if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY { + ipFunc = lib.GetOutboundIP + } + + ip, err := ipFunc() if err != nil { return "" diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 1c3f4d6..4afad08 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -23,6 +23,13 @@ const ( CLIENT_ROLE NodeType = "client" ) +type IPDiscovery string + +const ( + PUBLIC_IP_DISCOVERY = "public" + DNS_IP_DISCOVERY = "dns" +) + type WgMeshConfiguration struct { // CertificatePath is the path to the certificate to use in mTLS CertificatePath string `yaml:"certificatePath"` @@ -35,6 +42,9 @@ type WgMeshConfiguration struct { SkipCertVerification bool `yaml:"skipCertVerification"` // Port to run the GrpcServer on GrpcPort string `yaml:"gRPCPort"` + // IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or + // use public IP discovery library + IPDiscovery IPDiscovery `yaml:"ipDiscovery"` // AdvertiseRoutes advertises other meshes if the node is in multiple meshes AdvertiseRoutes bool `yaml:"advertiseRoutes"` // Endpoint is the IP in which this computer is publicly reachable. @@ -151,6 +161,10 @@ func ValidateConfiguration(c *WgMeshConfiguration) error { c.Role = PEER_ROLE } + if c.IPDiscovery == "" { + c.IPDiscovery = PUBLIC_IP_DISCOVERY + } + return nil } diff --git a/pkg/grpc/ctrlserver/ctrlserver.proto b/pkg/grpc/ctrlserver/ctrlserver.proto index e71ff0e..51a9366 100644 --- a/pkg/grpc/ctrlserver/ctrlserver.proto +++ b/pkg/grpc/ctrlserver/ctrlserver.proto @@ -4,13 +4,13 @@ package rpctypes; option go_package = "pkg/rpc"; service MeshCtrlServer { - rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {} + rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {} } -message JoinMeshRequest { - string meshId = 2; +message GetMeshRequest { + string meshId = 1; } -message JoinMeshReply { - bool success = 1; +message GetMeshReply { + bytes mesh = 1; } \ No newline at end of file diff --git a/pkg/lib/ip.go b/pkg/lib/ip.go index 0bd3740..84e6634 100644 --- a/pkg/lib/ip.go +++ b/pkg/lib/ip.go @@ -9,14 +9,14 @@ import ( ) // GetOutboundIP: gets the oubound IP of this packet -func GetOutboundIP() net.IP { +func GetOutboundIP() (net.IP, error) { conn, err := net.Dial("udp", "8.8.8.8:80") if err != nil { log.Fatal(err) } defer conn.Close() localAddr := conn.LocalAddr().(*net.UDPAddr) - return localAddr.IP + return localAddr.IP, nil } const IP_SERVICE = "https://api.ipify.org?format=json" diff --git a/pkg/mesh/manager.go b/pkg/mesh/manager.go index 70b1e05..961ecb0 100644 --- a/pkg/mesh/manager.go +++ b/pkg/mesh/manager.go @@ -57,7 +57,7 @@ type MeshManagerImpl struct { // RemoveService implements MeshManager. func (m *MeshManagerImpl) RemoveService(service string) error { for _, mesh := range m.Meshes { - err := mesh.RemoveService(m.HostParameters.HostEndpoint, service) + err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service) if err != nil { return err @@ -70,7 +70,7 @@ func (m *MeshManagerImpl) RemoveService(service string) error { // SetService implements MeshManager. func (m *MeshManagerImpl) SetService(service string, value string) error { for _, mesh := range m.Meshes { - err := mesh.AddService(m.HostParameters.HostEndpoint, service, value) + err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value) if err != nil { return err @@ -125,7 +125,7 @@ func (m *MeshManagerImpl) CreateMesh(port int) (string, error) { } if !m.conf.StubWg { - ifName, err = m.interfaceManipulator.CreateInterface(port) + ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey) if err != nil { return "", fmt.Errorf("error creating mesh: %w", err) @@ -160,7 +160,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { var err error if !m.conf.StubWg { - ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort) + ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey) if err != nil { return err @@ -283,7 +283,6 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { NodeIP: nodeIP, WgPort: params.WgPort, Endpoint: params.Endpoint, - Role: s.conf.Role, }) if !s.conf.StubWg { @@ -339,8 +338,8 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { return nil, fmt.Errorf("mesh %s does not exist", meshId) } - logging.Log.WriteInfof(s.HostParameters.HostEndpoint) - node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint) + logging.Log.WriteInfof(s.HostParameters.GetPublicKey()) + node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey()) if err != nil { return nil, errors.New("the node doesn't exist in the mesh") @@ -365,8 +364,8 @@ func (s *MeshManagerImpl) ApplyConfig() error { func (s *MeshManagerImpl) SetDescription(description string) error { for _, mesh := range s.Meshes { - if mesh.NodeExists(s.HostParameters.HostEndpoint) { - err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) + if mesh.NodeExists(s.HostParameters.GetPublicKey()) { + err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description) if err != nil { return err @@ -380,8 +379,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error { // SetAlias implements MeshManager. func (s *MeshManagerImpl) SetAlias(alias string) error { for _, mesh := range s.Meshes { - if mesh.NodeExists(s.HostParameters.HostEndpoint) { - err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) + if mesh.NodeExists(s.HostParameters.GetPublicKey()) { + err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias) if err != nil { return err @@ -394,8 +393,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error { // UpdateTimeStamp updates the timestamp of this node in all meshes func (s *MeshManagerImpl) UpdateTimeStamp() error { for _, mesh := range s.Meshes { - if mesh.NodeExists(s.HostParameters.HostEndpoint) { - err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) + if mesh.NodeExists(s.HostParameters.GetPublicKey()) { + err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey()) if err != nil { return err @@ -452,18 +451,11 @@ type NewMeshManagerParams struct { // Creates a new instance of a mesh manager with the given parameters func NewMeshManager(params *NewMeshManagerParams) MeshManager { - hostParams := HostParameters{} - - switch params.Conf.Endpoint { - case "": - ip, _ := lib.GetPublicIP() - hostParams.HostEndpoint = fmt.Sprintf("%s:%s", ip.String(), params.Conf.GrpcPort) - default: - hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort) + privateKey, _ := wgtypes.GeneratePrivateKey() + hostParams := HostParameters{ + PrivateKey: &privateKey, } - logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint) - m := &MeshManagerImpl{ Meshes: make(map[string]MeshProvider), HostParameters: &hostParams, diff --git a/pkg/mesh/types.go b/pkg/mesh/types.go index fe38b56..daf99b4 100644 --- a/pkg/mesh/types.go +++ b/pkg/mesh/types.go @@ -106,7 +106,12 @@ type MeshProvider interface { // HostParameters contains the IDs of a node type HostParameters struct { - HostEndpoint string + PrivateKey *wgtypes.Key +} + +// GetPublicKey: gets the public key of the node +func (h *HostParameters) GetPublicKey() string { + return h.PrivateKey.PublicKey().String() } // MeshProviderFactoryParams parameters required to build a mesh provider @@ -130,7 +135,6 @@ type MeshNodeFactoryParams struct { NodeIP net.IP WgPort int Endpoint string - Role conf.NodeType } // MeshBuilder build the hosts mesh node for it to be added to the mesh diff --git a/pkg/robin/requester.go b/pkg/robin/requester.go index e98a228..f349b82 100644 --- a/pkg/robin/requester.go +++ b/pkg/robin/requester.go @@ -72,7 +72,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + configuration := n.Server.GetConfiguration() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout)) defer cancel() meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) diff --git a/pkg/robin/responder.go b/pkg/robin/responder.go index aa7e771..3257b36 100644 --- a/pkg/robin/responder.go +++ b/pkg/robin/responder.go @@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc. return &reply, nil } - -func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { - return &rpc.JoinMeshReply{Success: true}, nil -} diff --git a/pkg/rpc/ctrlserver.pb.go b/pkg/rpc/ctrlserver.pb.go index cbcb061..62e3fd2 100644 --- a/pkg/rpc/ctrlserver.pb.go +++ b/pkg/rpc/ctrlserver.pb.go @@ -20,77 +20,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type MeshNode struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"` - WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"` - Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"` - WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"` -} - -func (x *MeshNode) Reset() { - *x = MeshNode{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MeshNode) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MeshNode) ProtoMessage() {} - -func (x *MeshNode) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead. -func (*MeshNode) Descriptor() ([]byte, []int) { - return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0} -} - -func (x *MeshNode) GetPublicKey() string { - if x != nil { - return x.PublicKey - } - return "" -} - -func (x *MeshNode) GetWgEndpoint() string { - if x != nil { - return x.WgEndpoint - } - return "" -} - -func (x *MeshNode) GetEndpoint() string { - if x != nil { - return x.Endpoint - } - return "" -} - -func (x *MeshNode) GetWgHost() string { - if x != nil { - return x.WgHost - } - return "" -} - type GetMeshRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -102,7 +31,7 @@ type GetMeshRequest struct { func (x *GetMeshRequest) Reset() { *x = GetMeshRequest{} if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] + mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string { func (*GetMeshRequest) ProtoMessage() {} func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] + mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead. func (*GetMeshRequest) Descriptor() ([]byte, []int) { - return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1} + return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0} } func (x *GetMeshRequest) GetMeshId() string { @@ -149,7 +78,7 @@ type GetMeshReply struct { func (x *GetMeshReply) Reset() { *x = GetMeshReply{} if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] + mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string { func (*GetMeshReply) ProtoMessage() {} func (x *GetMeshReply) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] + mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message { // Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead. func (*GetMeshReply) Descriptor() ([]byte, []int) { - return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2} + return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1} } func (x *GetMeshReply) GetMesh() []byte { @@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte { return nil } -type JoinMeshRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"` - MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"` -} - -func (x *JoinMeshRequest) Reset() { - *x = JoinMeshRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *JoinMeshRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*JoinMeshRequest) ProtoMessage() {} - -func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead. -func (*JoinMeshRequest) Descriptor() ([]byte, []int) { - return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3} -} - -func (x *JoinMeshRequest) GetChanges() []byte { - if x != nil { - return x.Changes - } - return nil -} - -func (x *JoinMeshRequest) GetMeshId() string { - if x != nil { - return x.MeshId - } - return "" -} - -type JoinMeshReply struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` -} - -func (x *JoinMeshReply) Reset() { - *x = JoinMeshReply{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *JoinMeshReply) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*JoinMeshReply) ProtoMessage() {} - -func (x *JoinMeshReply) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead. -func (*JoinMeshReply) Descriptor() ([]byte, []int) { - return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4} -} - -func (x *JoinMeshReply) GetSuccess() bool { - if x != nil { - return x.Success - } - return false -} - var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{ 0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, - 0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09, - 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67, - 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, - 0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, - 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, - 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28, - 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, - 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f, - 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, - 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, - 0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, - 0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, - 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a, - 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, - 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63, - 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, - 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40, - 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63, - 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, - 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, - 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, + 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, + 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f, + 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, + 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, + 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, + 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( @@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte { return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData } -var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{ - (*MeshNode)(nil), // 0: rpctypes.MeshNode - (*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest - (*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply - (*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest - (*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply + (*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest + (*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply } var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{ - 1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest - 3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest - 2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply - 4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply - 2, // [2:4] is the sub-list for method output_type - 0, // [0:2] is the sub-list for method input_type + 0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest + 1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() { } if !protoimpl.UnsafeEnabled { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*MeshNode); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetMeshRequest); i { case 0: return &v.state @@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() { return nil } } - file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GetMeshReply); i { case 0: return &v.state @@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() { return nil } } - file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*JoinMeshRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*JoinMeshReply); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } } type x struct{} out := protoimpl.TypeBuilder{ @@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 2, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/rpc/ctrlserver_grpc.pb.go b/pkg/rpc/ctrlserver_grpc.pb.go index 0a1f501..09aab0e 100644 --- a/pkg/rpc/ctrlserver_grpc.pb.go +++ b/pkg/rpc/ctrlserver_grpc.pb.go @@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type MeshCtrlServerClient interface { GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) - JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) } type meshCtrlServerClient struct { @@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest, return out, nil } -func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) { - out := new(JoinMeshReply) - err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - // MeshCtrlServerServer is the server API for MeshCtrlServer service. // All implementations must embed UnimplementedMeshCtrlServerServer // for forward compatibility type MeshCtrlServerServer interface { GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) - JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) mustEmbedUnimplementedMeshCtrlServerServer() } @@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct { func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) { return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented") } -func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) { - return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented") -} func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {} // UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service. @@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f return interceptor(ctx, in, info, handler) } -func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(JoinMeshRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(MeshCtrlServerServer).JoinMesh(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest)) - } - return interceptor(ctx, in, info, handler) -} - // MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetMesh", Handler: _MeshCtrlServer_GetMesh_Handler, }, - { - MethodName: "JoinMesh", - Handler: _MeshCtrlServer_JoinMesh_Handler, - }, }, Streams: []grpc.StreamDesc{}, Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto", diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index ac7c8cb..ca6fc48 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -45,14 +45,19 @@ func (s *SyncerImpl) Sync(meshId string) error { } nodeNames := s.manager.GetMesh(meshId).GetPeers() - self, err := s.manager.GetSelf(meshId) if err != nil { return err } - neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint()) + selfPublickey, err := self.GetPublicKey() + + if err != nil { + return err + } + + neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String()) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) for _, node := range randomSubset { @@ -63,7 +68,7 @@ func (s *SyncerImpl) Sync(meshId string) error { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { logging.Log.WriteInfof("Sending to random cluster") - interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) + interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String()) randomSubset = append(randomSubset, interCluster) } @@ -74,7 +79,14 @@ func (s *SyncerImpl) Sync(meshId string) error { go func(i int) error { defer waitGroup.Done() - err := s.requester.SyncMesh(meshId, randomSubset[i]) + + correspondingPeer := s.manager.GetNode(meshId, randomSubset[i]) + + if correspondingPeer == nil { + logging.Log.WriteErrorf("node %s does not exist", randomSubset[i]) + } + + err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint()) return err }(index) } diff --git a/pkg/wg/types.go b/pkg/wg/types.go index 99f22b1..d431b53 100644 --- a/pkg/wg/types.go +++ b/pkg/wg/types.go @@ -1,5 +1,7 @@ package wg +import "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + type WgError struct { msg string } @@ -10,7 +12,7 @@ func (m *WgError) Error() string { type WgInterfaceManipulator interface { // CreateInterface creates a WireGuard interface - CreateInterface(port int) (string, error) + CreateInterface(port int, privateKey *wgtypes.Key) (string, error) // AddAddress adds an address to the given interface name AddAddress(ifName string, addr string) error // RemoveInterface removes the specified interface diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index 70ba6af..9a4396a 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -19,7 +19,7 @@ type WgInterfaceManipulatorImpl struct { const hashLength = 6 // CreateInterface creates a WireGuard interface -func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) { +func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) { rtnl, err := lib.NewRtNetlinkConfig() if err != nil { @@ -44,14 +44,8 @@ func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) { return "", fmt.Errorf("failed to create link: %w", err) } - privateKey, err := wgtypes.GeneratePrivateKey() - - if err != nil { - return "", fmt.Errorf("failed to create private key: %w", err) - } - var cfg wgtypes.Config = wgtypes.Config{ - PrivateKey: &privateKey, + PrivateKey: privKey, ListenPort: &port, }