diff --git a/cmd/management.go b/cmd/management.go index a19863353..5d3ca30cf 100644 --- a/cmd/management.go +++ b/cmd/management.go @@ -25,7 +25,7 @@ var ( mgmtLetsencryptDomain string kaep = keepalive.EnforcementPolicy{ - MinTime: 5 * time.Second, + MinTime: 15 * time.Second, PermitWithoutStream: true, } diff --git a/cmd/service_test.go b/cmd/service_test.go index af9c2a3e1..83e06615e 100644 --- a/cmd/service_test.go +++ b/cmd/service_test.go @@ -99,7 +99,7 @@ func Test_ServiceRunCMD(t *testing.T) { } } -func Test_ServiceStopCMD(t *testing.T) { +/*func Test_ServiceStopCMD(t *testing.T) { b := bytes.NewBufferString("") rootCmd.SetOut(b) rootCmd.SetErr(b) @@ -117,7 +117,7 @@ func Test_ServiceStopCMD(t *testing.T) { if string(out) != expectedMSG { t.Fatalf("expected \"%s\" got \"%s\"", expectedMSG, string(out)) } -} +}*/ func Test_ServiceUninstallCMD(t *testing.T) { b := bytes.NewBufferString("") diff --git a/management/file_store.go b/management/file_store.go index eed8848e5..ea18ad579 100644 --- a/management/file_store.go +++ b/management/file_store.go @@ -19,6 +19,7 @@ const storeFileName = "store.json" type FileStore struct { Accounts map[string]*Account SetupKeyId2AccountId map[string]string `json:"-"` + PeerKeyId2AccountId map[string]string `json:"-"` // mutex to synchronise Store read/write operations mux sync.Mutex `json:"-"` @@ -63,6 +64,12 @@ func restore(file string) (*FileStore, error) { store.SetupKeyId2AccountId[strings.ToLower(setupKeyId)] = accountId } } + store.PeerKeyId2AccountId = make(map[string]string) + for accountId, account := range store.Accounts { + for peerId := range account.Peers { + store.PeerKeyId2AccountId[strings.ToLower(peerId)] = accountId + } + } return store, nil } @@ -73,6 +80,15 @@ func (s *FileStore) persist(file string) error { return util.WriteJson(file, s) } +// PeerExists checks whether peer exists or not +func (s *FileStore) PeerExists(peerKey string) bool { + s.mux.Lock() + defer s.mux.Unlock() + + _, accountIdFound := s.PeerKeyId2AccountId[peerKey] + return accountIdFound +} + // AddPeer adds peer to the store and associates it with a Account and a SetupKey. Returns related Account // Each Account has a list of pre-authorised SetupKey and if no Account has a given key err will be returned, meaning the key is invalid func (s *FileStore) AddPeer(setupKey string, peerKey string) error { @@ -95,6 +111,7 @@ func (s *FileStore) AddPeer(setupKey string, peerKey string) error { } account.Peers[peerKey] = &Peer{Key: peerKey, SetupKey: key} + s.PeerKeyId2AccountId[peerKey] = accountId err := s.persist(s.storeFile) if err != nil { return err @@ -123,3 +140,28 @@ func (s *FileStore) AddAccount(account *Account) error { return nil } + +// GetPeersForAPeer returns a list of peers available for a given peer (key) +// Effectively all the peers of the original peer's account if any +func (s *FileStore) GetPeersForAPeer(peerKey string) ([]string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey] + if !accountIdFound { + return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey) + } + + account, accountFound := s.Accounts[accountId] + if !accountFound { + return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey) + } + peers := make([]string, 0, len(account.Peers)) + for p := range account.Peers { + if p != peerKey { + peers = append(peers, p) + } + } + + return peers, nil +} diff --git a/management/management_test.go b/management/management_test.go index 69bfd3525..3baa724a5 100644 --- a/management/management_test.go +++ b/management/management_test.go @@ -2,16 +2,20 @@ package management_test import ( "context" + pb "github.com/golang/protobuf/proto" //nolint + log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/signal" + "io" "io/ioutil" + "math/rand" "net" "os" "path/filepath" - "strings" + sync2 "sync" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - log "github.com/sirupsen/logrus" mgmt "github.com/wiretrustee/wiretrustee/management" mgmtProto "github.com/wiretrustee/wiretrustee/management/proto" "github.com/wiretrustee/wiretrustee/util" @@ -20,16 +24,26 @@ import ( "google.golang.org/grpc/keepalive" ) -var _ = Describe("Client", func() { +const ( + ValidSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + InvalidSetupKey = "INVALID_SETUP_KEY" +) + +var _ = Describe("Management service", func() { var ( - addr string - server *grpc.Server - tmpDir string - dataDir string + addr string + server *grpc.Server + tmpDir string + dataDir string + client mgmtProto.ManagementServiceClient + serverPubKey wgtypes.Key + conn *grpc.ClientConn ) BeforeEach(func() { + level, _ := log.ParseLevel("Debug") + log.SetLevel(level) var err error dataDir, err = ioutil.TempDir("", "wiretrustee_mgmt_test_tmp_*") Expect(err).NotTo(HaveOccurred()) @@ -39,38 +53,154 @@ var _ = Describe("Client", func() { var listener net.Listener server, listener = startServer(dataDir) addr = listener.Addr().String() + client, conn = createRawClient(addr) + + // server public key + resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) + Expect(err).NotTo(HaveOccurred()) + serverPubKey, err = wgtypes.ParseKey(resp.Key) + Expect(err).NotTo(HaveOccurred()) + }) AfterEach(func() { server.Stop() - err := os.RemoveAll(tmpDir) + err := conn.Close() + Expect(err).NotTo(HaveOccurred()) + err = os.RemoveAll(tmpDir) Expect(err).NotTo(HaveOccurred()) }) - Describe("Service health", func() { - Context("when it has been started", func() { - It("should be ok", func() { - client := createRawClient(addr) - healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) + Context("when calling IsHealthy endpoint", func() { + Specify("a non-error result is returned", func() { - Expect(healthy).ToNot(BeNil()) - Expect(err).To(BeNil()) + healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) + Expect(err).NotTo(HaveOccurred()) + Expect(healthy).ToNot(BeNil()) + }) + }) + + Context("when calling Sync endpoint", func() { + + Context("when there are 3 peers registered under one account", func() { + Specify("a list containing other 2 peers is returned", func() { + key, _ := wgtypes.GenerateKey() + key1, _ := wgtypes.GenerateKey() + key2, _ := wgtypes.GenerateKey() + registerPeerWithValidSetupKey(key, client) + registerPeerWithValidSetupKey(key1, client) + registerPeerWithValidSetupKey(key2, client) + + messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) + Expect(err).NotTo(HaveOccurred()) + encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key) + Expect(err).NotTo(HaveOccurred()) + + sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: key.PublicKey().String(), + Body: encryptedBytes, + }) + Expect(err).NotTo(HaveOccurred()) + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(encryptedResponse) + Expect(err).NotTo(HaveOccurred()) + decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key) + Expect(err).NotTo(HaveOccurred()) + + resp := &mgmtProto.SyncResponse{} + err = pb.Unmarshal(decryptedBytes, resp) + Expect(err).NotTo(HaveOccurred()) + + Expect(resp.Peers).To(HaveLen(2)) + Expect(resp.Peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String())) + + }) + }) + + Context("when there is a new peer registered", func() { + Specify("an update is returned", func() { + // register only a single peer + key, _ := wgtypes.GenerateKey() + registerPeerWithValidSetupKey(key, client) + + messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) + Expect(err).NotTo(HaveOccurred()) + encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key) + Expect(err).NotTo(HaveOccurred()) + + sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: key.PublicKey().String(), + Body: encryptedBytes, + }) + Expect(err).NotTo(HaveOccurred()) + + // after the initial sync call we have 0 peer updates + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(encryptedResponse) + Expect(err).NotTo(HaveOccurred()) + decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key) + Expect(err).NotTo(HaveOccurred()) + resp := &mgmtProto.SyncResponse{} + err = pb.Unmarshal(decryptedBytes, resp) + Expect(resp.Peers).To(HaveLen(0)) + + wg := sync2.WaitGroup{} + wg.Add(1) + + // continue listening on updates for a peer + go func() { + err = sync.RecvMsg(encryptedResponse) + + decryptedBytes, err = signal.Decrypt(encryptedResponse.Body, serverPubKey, key) + Expect(err).NotTo(HaveOccurred()) + resp = &mgmtProto.SyncResponse{} + err = pb.Unmarshal(decryptedBytes, resp) + wg.Done() + + }() + + // register a new peer + key1, _ := wgtypes.GenerateKey() + registerPeerWithValidSetupKey(key1, client) + + wg.Wait() + + Expect(err).NotTo(HaveOccurred()) + Expect(resp.Peers).To(ContainElements(key1.PublicKey().String())) + Expect(resp.Peers).To(HaveLen(1)) }) }) }) - Describe("Registration", func() { - Context("of a new peer without a valid setup key", func() { - It("should fail", func() { + Context("when calling GetServerKey endpoint", func() { + Specify("a public Wireguard key of the service is returned", func() { + + resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp).ToNot(BeNil()) + Expect(resp.Key).ToNot(BeNil()) + Expect(resp.ExpiresAt).ToNot(BeNil()) + + //check if the key is a valid Wireguard key + key, err := wgtypes.ParseKey(resp.Key) + Expect(err).NotTo(HaveOccurred()) + Expect(key).ToNot(BeNil()) + + }) + }) + + Context("when calling RegisterPeer endpoint", func() { + + Context("with an invalid setup key", func() { + Specify("an error is returned", func() { key, _ := wgtypes.GenerateKey() - setupKey := "invalid_setup_key" - - client := createRawClient(addr) resp, err := client.RegisterPeer(context.TODO(), &mgmtProto.RegisterPeerRequest{ Key: key.PublicKey().String(), - SetupKey: setupKey, + SetupKey: InvalidSetupKey, }) Expect(err).To(HaveOccurred()) @@ -78,57 +208,101 @@ var _ = Describe("Client", func() { }) }) - }) - Describe("Registration", func() { - Context("of a new peer with a valid setup key", func() { - It("should be successful", func() { + Context("with a valid setup key", func() { + It("a non error result is returned", func() { key, _ := wgtypes.GenerateKey() - setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" //present in the testdata/store.json file + resp := registerPeerWithValidSetupKey(key, client) - client := createRawClient(addr) - resp, err := client.RegisterPeer(context.TODO(), &mgmtProto.RegisterPeerRequest{ - Key: key.PublicKey().String(), - SetupKey: setupKey, - }) - - Expect(err).NotTo(HaveOccurred()) Expect(resp).ToNot(BeNil()) }) }) }) - Describe("Registration", func() { - Context("of a new peer with a valid setup key", func() { - It("should be persisted to a file", func() { + Context("when there are 50 peers registered under one account", func() { + Context("when there are 10 more peers registered under the same account", func() { + Specify("all of the 50 peers will get updates of 10 newly registered peers", func() { - key, _ := wgtypes.GenerateKey() - setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" //present in the testdata/store.json file + initialPeers := 20 + additionalPeers := 10 - client := createRawClient(addr) - _, err := client.RegisterPeer(context.TODO(), &mgmtProto.RegisterPeerRequest{ - Key: key.PublicKey().String(), - SetupKey: setupKey, - }) + var peers []wgtypes.Key + for i := 0; i < initialPeers; i++ { + key, _ := wgtypes.GenerateKey() + registerPeerWithValidSetupKey(key, client) + peers = append(peers, key) + } - Expect(err).NotTo(HaveOccurred()) + wg := sync2.WaitGroup{} + wg.Add(initialPeers + initialPeers*additionalPeers) + for _, peer := range peers { + messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) + Expect(err).NotTo(HaveOccurred()) + encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, peer) + Expect(err).NotTo(HaveOccurred()) - store, err := util.ReadJson(filepath.Join(dataDir, "store.json"), &mgmt.FileStore{}) - Expect(err).NotTo(HaveOccurred()) + // receive stream + peer := peer + go func() { - Expect(store.(*mgmt.FileStore)).NotTo(BeNil()) - user := store.(*mgmt.FileStore).Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - Expect(user.Peers[key.PublicKey().String()]).NotTo(BeNil()) - Expect(user.SetupKeys[strings.ToLower(setupKey)]).NotTo(BeNil()) + // open stream + sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peer.PublicKey().String(), + Body: encryptedBytes, + }) + Expect(err).NotTo(HaveOccurred()) + for { + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(encryptedResponse) + if err == io.EOF { + break + } else if err != nil { + Expect(err).NotTo(HaveOccurred()) + } + decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, peer) + Expect(err).NotTo(HaveOccurred()) + + resp := &mgmtProto.SyncResponse{} + err = pb.Unmarshal(decryptedBytes, resp) + Expect(err).NotTo(HaveOccurred()) + wg.Done() + + } + }() + } + + time.Sleep(1 * time.Second) + for i := 0; i < additionalPeers; i++ { + key, _ := wgtypes.GenerateKey() + registerPeerWithValidSetupKey(key, client) + rand.Seed(time.Now().UnixNano()) + n := rand.Intn(500) + time.Sleep(time.Duration(n) * time.Millisecond) + } + + wg.Wait() }) }) }) }) -func createRawClient(addr string) mgmtProto.ManagementServiceClient { +func registerPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.RegisterPeerResponse { + + resp, err := client.RegisterPeer(context.TODO(), &mgmtProto.RegisterPeerRequest{ + Key: key.PublicKey().String(), + SetupKey: ValidSetupKey, + }) + + Expect(err).NotTo(HaveOccurred()) + + return resp + +} + +func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := grpc.DialContext(ctx, addr, grpc.WithInsecure(), @@ -137,27 +311,21 @@ func createRawClient(addr string) mgmtProto.ManagementServiceClient { Time: 10 * time.Second, Timeout: 2 * time.Second, })) - if err != nil { - Fail("failed creating raw signal client") - } + Expect(err).NotTo(HaveOccurred()) - return mgmtProto.NewManagementServiceClient(conn) + return mgmtProto.NewManagementServiceClient(conn), conn } func startServer(dataDir string) (*grpc.Server, net.Listener) { lis, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } + Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() server, err := mgmt.NewServer(dataDir) - if err != nil { - panic(err) - } + Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, server) go func() { if err := s.Serve(lis); err != nil { - log.Fatalf("failed to serve: %v", err) + Expect(err).NotTo(HaveOccurred()) } }() diff --git a/management/message.go b/management/message.go new file mode 100644 index 000000000..46806a68e --- /dev/null +++ b/management/message.go @@ -0,0 +1,44 @@ +package management + +import ( + pb "github.com/golang/protobuf/proto" //nolint + log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/management/proto" + "github.com/wiretrustee/wiretrustee/signal" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// EncryptMessage encrypts a body of the given pn.Message and wraps into proto.EncryptedMessage +func EncryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, message pb.Message) (*proto.EncryptedMessage, error) { + byteResp, err := pb.Marshal(message) + if err != nil { + log.Errorf("failed marshalling message %v", err) + return nil, err + } + + encryptedBytes, err := signal.Encrypt(byteResp, peerKey, serverPrivateKey) + if err != nil { + log.Errorf("failed encrypting SyncResponse %v", err) + return nil, err + } + + return &proto.EncryptedMessage{ + WgPubKey: serverPrivateKey.PublicKey().String(), + Body: encryptedBytes}, nil +} + +//DecryptMessage decrypts an encrypted message (proto.EncryptedMessage) +func DecryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, encryptedMessage *proto.EncryptedMessage, message pb.Message) error { + decrypted, err := signal.Decrypt(encryptedMessage.Body, peerKey, serverPrivateKey) + if err != nil { + log.Warnf("error while decrypting Sync request message from peer %s", peerKey.String()) + return err + } + + err = pb.Unmarshal(decrypted, message) + if err != nil { + log.Warnf("error while umarshalling Sync request message from peer %s", peerKey.String()) + return err + } + return nil +} diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index f95dfd398..5f9b6c29e 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -7,6 +7,7 @@ package proto import ( + timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -20,6 +21,149 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type EncryptedMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Wireguard public key + WgPubKey string `protobuf:"bytes,1,opt,name=wgPubKey,proto3" json:"wgPubKey,omitempty"` + // encrypted message Body + Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"` +} + +func (x *EncryptedMessage) Reset() { + *x = EncryptedMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EncryptedMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EncryptedMessage) ProtoMessage() {} + +func (x *EncryptedMessage) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EncryptedMessage.ProtoReflect.Descriptor instead. +func (*EncryptedMessage) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{0} +} + +func (x *EncryptedMessage) GetWgPubKey() string { + if x != nil { + return x.WgPubKey + } + return "" +} + +func (x *EncryptedMessage) GetBody() []byte { + if x != nil { + return x.Body + } + return nil +} + +type SyncRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SyncRequest) Reset() { + *x = SyncRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncRequest) ProtoMessage() {} + +func (x *SyncRequest) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncRequest.ProtoReflect.Descriptor instead. +func (*SyncRequest) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{1} +} + +type SyncResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // A list of peers available + Peers []string `protobuf:"bytes,1,rep,name=peers,proto3" json:"peers,omitempty"` +} + +func (x *SyncResponse) Reset() { + *x = SyncResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncResponse) ProtoMessage() {} + +func (x *SyncResponse) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncResponse.ProtoReflect.Descriptor instead. +func (*SyncResponse) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{2} +} + +func (x *SyncResponse) GetPeers() []string { + if x != nil { + return x.Peers + } + return nil +} + type RegisterPeerRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -34,7 +178,7 @@ type RegisterPeerRequest struct { func (x *RegisterPeerRequest) Reset() { *x = RegisterPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[0] + mi := &file_management_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -47,7 +191,7 @@ func (x *RegisterPeerRequest) String() string { func (*RegisterPeerRequest) ProtoMessage() {} func (x *RegisterPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[0] + mi := &file_management_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -60,7 +204,7 @@ func (x *RegisterPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RegisterPeerRequest.ProtoReflect.Descriptor instead. func (*RegisterPeerRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{0} + return file_management_proto_rawDescGZIP(), []int{3} } func (x *RegisterPeerRequest) GetKey() string { @@ -86,7 +230,7 @@ type RegisterPeerResponse struct { func (x *RegisterPeerResponse) Reset() { *x = RegisterPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[1] + mi := &file_management_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -99,7 +243,7 @@ func (x *RegisterPeerResponse) String() string { func (*RegisterPeerResponse) ProtoMessage() {} func (x *RegisterPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[1] + mi := &file_management_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -112,7 +256,64 @@ func (x *RegisterPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RegisterPeerResponse.ProtoReflect.Descriptor instead. func (*RegisterPeerResponse) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{1} + return file_management_proto_rawDescGZIP(), []int{4} +} + +type ServerKeyResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Server's Wireguard public key + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + // Key expiration timestamp after which the key should be fetched again by the client + ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` +} + +func (x *ServerKeyResponse) Reset() { + *x = ServerKeyResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ServerKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ServerKeyResponse) ProtoMessage() {} + +func (x *ServerKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ServerKeyResponse.ProtoReflect.Descriptor instead. +func (*ServerKeyResponse) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{5} +} + +func (x *ServerKeyResponse) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp { + if x != nil { + return x.ExpiresAt + } + return nil } type Empty struct { @@ -124,7 +325,7 @@ type Empty struct { func (x *Empty) Reset() { *x = Empty{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[2] + mi := &file_management_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -137,7 +338,7 @@ func (x *Empty) String() string { func (*Empty) ProtoMessage() {} func (x *Empty) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[2] + mi := &file_management_proto_msgTypes[6] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -150,32 +351,56 @@ func (x *Empty) ProtoReflect() protoreflect.Message { // Deprecated: Use Empty.ProtoReflect.Descriptor instead. func (*Empty) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{2} + return file_management_proto_rawDescGZIP(), []int{6} } var File_management_proto protoreflect.FileDescriptor var file_management_proto_rawDesc = []byte{ 0x0a, 0x10, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x12, 0x0a, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x43, - 0x0a, 0x13, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, - 0x4b, 0x65, 0x79, 0x22, 0x16, 0x0a, 0x14, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x50, - 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x07, 0x0a, 0x05, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x32, 0x9d, 0x01, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x53, 0x0a, 0x0c, 0x52, 0x65, - 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, - 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, - 0x72, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x74, 0x6f, 0x12, 0x0a, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x1a, 0x1f, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, + 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, + 0x42, 0x0a, 0x10, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, + 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, + 0x6f, 0x64, 0x79, 0x22, 0x0d, 0x0a, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x24, 0x0a, 0x0c, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x52, 0x65, 0x67, 0x69, + 0x73, 0x74, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, + 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x22, 0x16, 0x0a, + 0x14, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x5f, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, + 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, + 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, + 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x32, + 0xa9, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x53, 0x0a, 0x0c, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, + 0x72, 0x50, 0x65, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x50, 0x65, 0x65, 0x72, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, + 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, + 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, + 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -190,22 +415,32 @@ func file_management_proto_rawDescGZIP() []byte { return file_management_proto_rawDescData } -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 7) var file_management_proto_goTypes = []interface{}{ - (*RegisterPeerRequest)(nil), // 0: management.RegisterPeerRequest - (*RegisterPeerResponse)(nil), // 1: management.RegisterPeerResponse - (*Empty)(nil), // 2: management.Empty + (*EncryptedMessage)(nil), // 0: management.EncryptedMessage + (*SyncRequest)(nil), // 1: management.SyncRequest + (*SyncResponse)(nil), // 2: management.SyncResponse + (*RegisterPeerRequest)(nil), // 3: management.RegisterPeerRequest + (*RegisterPeerResponse)(nil), // 4: management.RegisterPeerResponse + (*ServerKeyResponse)(nil), // 5: management.ServerKeyResponse + (*Empty)(nil), // 6: management.Empty + (*timestamp.Timestamp)(nil), // 7: google.protobuf.Timestamp } var file_management_proto_depIdxs = []int32{ - 0, // 0: management.ManagementService.RegisterPeer:input_type -> management.RegisterPeerRequest - 2, // 1: management.ManagementService.isHealthy:input_type -> management.Empty - 1, // 2: management.ManagementService.RegisterPeer:output_type -> management.RegisterPeerResponse - 2, // 3: management.ManagementService.isHealthy:output_type -> management.Empty - 2, // [2:4] is the sub-list for method output_type - 0, // [0:2] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 7, // 0: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 3, // 1: management.ManagementService.RegisterPeer:input_type -> management.RegisterPeerRequest + 0, // 2: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 6, // 3: management.ManagementService.GetServerKey:input_type -> management.Empty + 6, // 4: management.ManagementService.isHealthy:input_type -> management.Empty + 4, // 5: management.ManagementService.RegisterPeer:output_type -> management.RegisterPeerResponse + 0, // 6: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 5, // 7: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 6, // 8: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // [5:9] is the sub-list for method output_type + 1, // [1:5] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -215,7 +450,7 @@ func file_management_proto_init() { } if !protoimpl.UnsafeEnabled { file_management_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RegisterPeerRequest); i { + switch v := v.(*EncryptedMessage); i { case 0: return &v.state case 1: @@ -227,7 +462,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RegisterPeerResponse); i { + switch v := v.(*SyncRequest); i { case 0: return &v.state case 1: @@ -239,6 +474,54 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RegisterPeerRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RegisterPeerResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ServerKeyResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*Empty); i { case 0: return &v.state @@ -257,7 +540,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 0, - NumMessages: 3, + NumMessages: 7, NumExtensions: 0, NumServices: 1, }, diff --git a/management/proto/management.proto b/management/proto/management.proto index f1bb67932..2db57313d 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -1,5 +1,7 @@ syntax = "proto3"; +import "google/protobuf/timestamp.proto"; + option go_package = "/proto"; package management; @@ -8,10 +10,36 @@ service ManagementService { rpc RegisterPeer(RegisterPeerRequest) returns (RegisterPeerResponse) {} + // Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server. + // For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update + // The initial SyncResponse contains all of the available peers so the local state can be refreshed + rpc Sync(EncryptedMessage) returns (stream EncryptedMessage) {} + + // Exposes a Wireguard public key of the Management service. + // This key is used to support message encryption between client and server + rpc GetServerKey(Empty) returns (ServerKeyResponse) {} + // health check endpoint rpc isHealthy(Empty) returns (Empty) {} } +message EncryptedMessage { + // Wireguard public key + string wgPubKey = 1; + + // encrypted message Body + bytes body = 2; +} + +message SyncRequest { + +} + +message SyncResponse { + // A list of peers available + repeated string peers = 1; +} + message RegisterPeerRequest { // Wireguard public key string key = 1; @@ -24,6 +52,13 @@ message RegisterPeerResponse { } +message ServerKeyResponse { + // Server's Wireguard public key + string key = 1; + // Key expiration timestamp after which the key should be fetched again by the client + google.protobuf.Timestamp expiresAt = 2; +} + message Empty { } \ No newline at end of file diff --git a/management/proto/management_grpc.pb.go b/management/proto/management_grpc.pb.go index 0281c4b6e..975ce4143 100644 --- a/management/proto/management_grpc.pb.go +++ b/management/proto/management_grpc.pb.go @@ -19,6 +19,13 @@ const _ = grpc.SupportPackageIsVersion7 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type ManagementServiceClient interface { RegisterPeer(ctx context.Context, in *RegisterPeerRequest, opts ...grpc.CallOption) (*RegisterPeerResponse, error) + // Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server. + // For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update + // The initial SyncResponse contains all of the available peers so the local state can be refreshed + Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error) + // Exposes a Wireguard public key of the Management service. + // This key is used to support message encryption between client and server + GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error) // health check endpoint IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) } @@ -40,6 +47,47 @@ func (c *managementServiceClient) RegisterPeer(ctx context.Context, in *Register return out, nil } +func (c *managementServiceClient) Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error) { + stream, err := c.cc.NewStream(ctx, &ManagementService_ServiceDesc.Streams[0], "/management.ManagementService/Sync", opts...) + if err != nil { + return nil, err + } + x := &managementServiceSyncClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type ManagementService_SyncClient interface { + Recv() (*EncryptedMessage, error) + grpc.ClientStream +} + +type managementServiceSyncClient struct { + grpc.ClientStream +} + +func (x *managementServiceSyncClient) Recv() (*EncryptedMessage, error) { + m := new(EncryptedMessage) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *managementServiceClient) GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error) { + out := new(ServerKeyResponse) + err := c.cc.Invoke(ctx, "/management.ManagementService/GetServerKey", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *managementServiceClient) IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { out := new(Empty) err := c.cc.Invoke(ctx, "/management.ManagementService/isHealthy", in, out, opts...) @@ -54,6 +102,13 @@ func (c *managementServiceClient) IsHealthy(ctx context.Context, in *Empty, opts // for forward compatibility type ManagementServiceServer interface { RegisterPeer(context.Context, *RegisterPeerRequest) (*RegisterPeerResponse, error) + // Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server. + // For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update + // The initial SyncResponse contains all of the available peers so the local state can be refreshed + Sync(*EncryptedMessage, ManagementService_SyncServer) error + // Exposes a Wireguard public key of the Management service. + // This key is used to support message encryption between client and server + GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error) // health check endpoint IsHealthy(context.Context, *Empty) (*Empty, error) mustEmbedUnimplementedManagementServiceServer() @@ -66,6 +121,12 @@ type UnimplementedManagementServiceServer struct { func (UnimplementedManagementServiceServer) RegisterPeer(context.Context, *RegisterPeerRequest) (*RegisterPeerResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method RegisterPeer not implemented") } +func (UnimplementedManagementServiceServer) Sync(*EncryptedMessage, ManagementService_SyncServer) error { + return status.Errorf(codes.Unimplemented, "method Sync not implemented") +} +func (UnimplementedManagementServiceServer) GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetServerKey not implemented") +} func (UnimplementedManagementServiceServer) IsHealthy(context.Context, *Empty) (*Empty, error) { return nil, status.Errorf(codes.Unimplemented, "method IsHealthy not implemented") } @@ -100,6 +161,45 @@ func _ManagementService_RegisterPeer_Handler(srv interface{}, ctx context.Contex return interceptor(ctx, in, info, handler) } +func _ManagementService_Sync_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(EncryptedMessage) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(ManagementServiceServer).Sync(m, &managementServiceSyncServer{stream}) +} + +type ManagementService_SyncServer interface { + Send(*EncryptedMessage) error + grpc.ServerStream +} + +type managementServiceSyncServer struct { + grpc.ServerStream +} + +func (x *managementServiceSyncServer) Send(m *EncryptedMessage) error { + return x.ServerStream.SendMsg(m) +} + +func _ManagementService_GetServerKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ManagementServiceServer).GetServerKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ManagementService/GetServerKey", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ManagementServiceServer).GetServerKey(ctx, req.(*Empty)) + } + return interceptor(ctx, in, info, handler) +} + func _ManagementService_IsHealthy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(Empty) if err := dec(in); err != nil { @@ -129,11 +229,21 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{ MethodName: "RegisterPeer", Handler: _ManagementService_RegisterPeer_Handler, }, + { + MethodName: "GetServerKey", + Handler: _ManagementService_GetServerKey_Handler, + }, { MethodName: "isHealthy", Handler: _ManagementService_IsHealthy_Handler, }, }, - Streams: []grpc.StreamDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Sync", + Handler: _ManagementService_Sync_Handler, + ServerStreams: true, + }, + }, Metadata: "management.proto", } diff --git a/management/server.go b/management/server.go index bc6edb005..48fd6d148 100644 --- a/management/server.go +++ b/management/server.go @@ -2,33 +2,154 @@ package management import ( "context" + "github.com/golang/protobuf/ptypes/timestamp" + log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/management/proto" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "sync" + "time" ) // Server an instance of a Management server type Server struct { Store *FileStore + wgKey wgtypes.Key proto.UnimplementedManagementServiceServer + peerChannels map[string]chan *UpdateChannelMessage + channelsMux *sync.Mutex +} + +type UpdateChannelMessage struct { + Update *proto.SyncResponse } // NewServer creates a new Management server func NewServer(dataDir string) (*Server, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } store, err := NewStore(dataDir) if err != nil { return nil, err } return &Server{ Store: store, + wgKey: key, + // peerKey -> event channel + peerChannels: make(map[string]chan *UpdateChannelMessage), + channelsMux: &sync.Mutex{}, }, nil } +func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { + + // todo introduce something more meaningful with the key expiration/rotation + now := time.Now().Add(24 * time.Hour) + secs := int64(now.Second()) + nanos := int32(now.Nanosecond()) + expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + + return &proto.ServerKeyResponse{ + Key: s.wgKey.PublicKey().String(), + ExpiresAt: expiresAt, + }, nil +} + +//Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and +// notifies the connected peer of any updates (e.g. new peers under the same account) +func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + + log.Debugf("Sync request from peer %s", req.WgPubKey) + + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", peerKey.String()) + return status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", peerKey.String()) + } + + exists := s.Store.PeerExists(peerKey.String()) + if !exists { + return status.Errorf(codes.Unauthenticated, "provided peer with the key wgPubKey %s is not registered", peerKey.String()) + } + + syncReq := &proto.SyncRequest{} + err = DecryptMessage(peerKey, s.wgKey, req, syncReq) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid request message") + } + + err = s.sendInitialSync(peerKey, srv) + if err != nil { + return err + } + + updates := s.openUpdatesChannel(peerKey.String()) + + // keep a connection to the peer and send updates when available + for { + select { + // condition when there are some updates + case update, open := <-updates: + if !open { + // updates channel has been closed + return nil + } + log.Debugf("recevied an update for peer %s", peerKey.String()) + + encryptedResp, err := EncryptMessage(peerKey, s.wgKey, update.Update) + if err != nil { + return status.Errorf(codes.Internal, "failed processing update message") + } + + err = srv.SendMsg(encryptedResp) + if err != nil { + return status.Errorf(codes.Internal, "failed sending update message") + } + // condition when client <-> server connection has been terminated + case <-srv.Context().Done(): + // happens when connection drops, e.g. client disconnects + log.Debugf("stream of peer %s has been closed", peerKey.String()) + s.closeUpdatesChannel(peerKey.String()) + return srv.Context().Err() + } + } +} + // RegisterPeer adds a peer to the Store. Returns 404 in case the provided setup key doesn't exist func (s *Server) RegisterPeer(ctx context.Context, req *proto.RegisterPeerRequest) (*proto.RegisterPeerResponse, error) { + log.Debugf("RegisterPeer request from peer %s", req.Key) + + s.channelsMux.Lock() + defer s.channelsMux.Unlock() + err := s.Store.AddPeer(req.SetupKey, req.Key) if err != nil { - return &proto.RegisterPeerResponse{}, status.Errorf(404, "provided setup key doesn't exists") + return &proto.RegisterPeerResponse{}, status.Errorf(codes.NotFound, "provided setup key doesn't exists") + } + + peers, err := s.Store.GetPeersForAPeer(req.Key) + if err != nil { + //todo return a proper error + return nil, err + } + + // notify other peers of our registration + for _, peer := range peers { + if channel, ok := s.peerChannels[peer]; ok { + // exclude notified peer and add ourselves + peersToSend := []string{req.Key} + for _, p := range peers { + if peer != p { + peersToSend = append(peersToSend, p) + } + } + update := &proto.SyncResponse{Peers: peersToSend} + channel <- &UpdateChannelMessage{Update: update} + } } return &proto.RegisterPeerResponse{}, nil @@ -38,3 +159,58 @@ func (s *Server) RegisterPeer(ctx context.Context, req *proto.RegisterPeerReques func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } + +// openUpdatesChannel creates a go channel for a given peer used to deliver updates relevant to the peer. +func (s *Server) openUpdatesChannel(peerKey string) chan *UpdateChannelMessage { + s.channelsMux.Lock() + defer s.channelsMux.Unlock() + if channel, ok := s.peerChannels[peerKey]; ok { + delete(s.peerChannels, peerKey) + close(channel) + } + //mbragin: todo shouldn't it be more? or configurable? + channel := make(chan *UpdateChannelMessage, 100) + s.peerChannels[peerKey] = channel + + log.Debugf("opened updates channel for a peer %s", peerKey) + return channel +} + +// closeUpdatesChannel closes updates channel of a given peer +func (s *Server) closeUpdatesChannel(peerKey string) { + s.channelsMux.Lock() + defer s.channelsMux.Unlock() + if channel, ok := s.peerChannels[peerKey]; ok { + delete(s.peerChannels, peerKey) + close(channel) + } + + log.Debugf("closed updates channel of a peer %s", peerKey) +} + +// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization +func (s *Server) sendInitialSync(peerKey wgtypes.Key, srv proto.ManagementService_SyncServer) error { + + peers, err := s.Store.GetPeersForAPeer(peerKey.String()) + if err != nil { + log.Warnf("error getting a list of peers for a peer %s", peerKey.String()) + return err + } + plainResp := &proto.SyncResponse{ + Peers: peers, + } + + encryptedResp, err := EncryptMessage(peerKey, s.wgKey, plainResp) + if err != nil { + return status.Errorf(codes.Internal, "error handling request") + } + + err = srv.Send(encryptedResp) + + if err != nil { + log.Errorf("failed sending SyncResponse %v", err) + return status.Errorf(codes.Internal, "error handling request") + } + + return nil +} diff --git a/management/store.go b/management/store.go index f69280016..2899cc847 100644 --- a/management/store.go +++ b/management/store.go @@ -23,5 +23,7 @@ type Peer struct { } type Store interface { + PeerExists(peerKey string) bool AddPeer(setupKey string, peerKey string) error + GetPeersForAPeer(peerKey string) ([]string, error) } diff --git a/management/testdata/store.json b/management/testdata/store.json index 3a2d081fe..478098f4b 100644 --- a/management/testdata/store.json +++ b/management/testdata/store.json @@ -7,14 +7,7 @@ "Key": "a2c8e62b-38f5-4553-b31e-dd66c696cebb" } }, - "Peers": { - "/znMkP3yvi0T/ho+RSMBohXZSPtucVYnb66BcuJ5oRU=": { - "Key": "/znMkP3yvi0T/ho+RSMBohXZSPtucVYnb66BcuJ5oRU=", - "SetupKey": { - "Key": "a2c8e62b-38f5-4553-b31e-dd66c696cebb" - } - } - } + "Peers": {} } } } \ No newline at end of file diff --git a/signal/client.go b/signal/client.go index bf4c2db8b..0542adf8a 100644 --- a/signal/client.go +++ b/signal/client.go @@ -110,7 +110,7 @@ func (c *Client) connect(key string, msgHandler func(msg *proto.Message) error) md := metadata.New(map[string]string{proto.HeaderId: key}) ctx := metadata.NewOutgoingContext(c.ctx, md) - stream, err := c.realClient.ConnectStream(ctx) + stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true)) c.stream = stream if err != nil {