diff --git a/management/management_test.go b/management/management_test.go index e5cfac3ad..499e1b528 100644 --- a/management/management_test.go +++ b/management/management_test.go @@ -5,7 +5,6 @@ import ( pb "github.com/golang/protobuf/proto" //nolint log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/encryption" - "io" "io/ioutil" "math/rand" "net" @@ -34,7 +33,6 @@ var _ = Describe("Management service", func() { var ( addr string server *grpc.Server - tmpDir string dataDir string client mgmtProto.ManagementServiceClient serverPubKey wgtypes.Key @@ -67,7 +65,7 @@ var _ = Describe("Management service", func() { server.Stop() err := conn.Close() Expect(err).NotTo(HaveOccurred()) - err = os.RemoveAll(tmpDir) + err = os.RemoveAll(dataDir) Expect(err).NotTo(HaveOccurred()) }) @@ -237,29 +235,30 @@ var _ = Describe("Management service", func() { wg := sync2.WaitGroup{} wg.Add(initialPeers + initialPeers*additionalPeers) + + var clients []mgmtProto.ManagementService_SyncClient for _, peer := range peers { messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) Expect(err).NotTo(HaveOccurred()) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer) Expect(err).NotTo(HaveOccurred()) + // open stream + sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peer.PublicKey().String(), + Body: encryptedBytes, + }) + Expect(err).NotTo(HaveOccurred()) + clients = append(clients, sync) + // receive stream peer := peer go func() { - - // open stream - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: peer.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) for { encryptedResponse := &mgmtProto.EncryptedMessage{} err = sync.RecvMsg(encryptedResponse) - if err == io.EOF { + if err != nil { break - } else if err != nil { - Expect(err).NotTo(HaveOccurred()) } decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer) Expect(err).NotTo(HaveOccurred()) @@ -268,7 +267,6 @@ var _ = Describe("Management service", func() { err = pb.Unmarshal(decryptedBytes, resp) Expect(err).NotTo(HaveOccurred()) wg.Done() - } }() } @@ -284,6 +282,11 @@ var _ = Describe("Management service", func() { wg.Wait() + for _, syncClient := range clients { + err := syncClient.CloseSend() + Expect(err).NotTo(HaveOccurred()) + } + }) }) })