mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-24 00:54:01 +01:00
Extract common server encryption logic (#65)
* refactor: extract common message encryption logic * refactor: move letsencrypt logic to common * refactor: rename common package to encryption * test: add encryption tests
This commit is contained in:
parent
c98be683bf
commit
2172d6f1b9
@ -1,21 +1,18 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
mgmt "github.com/wiretrustee/wiretrustee/management"
|
||||
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -52,34 +49,8 @@ var (
|
||||
var opts []grpc.ServerOption
|
||||
|
||||
if mgmtLetsencryptDomain != "" {
|
||||
|
||||
certDir := filepath.Join(mgmtDataDir, "letsencrypt")
|
||||
|
||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(certDir, os.ModeDir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", mgmtLetsencryptDomain, certDir)
|
||||
|
||||
certManager := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
HostPolicy: autocert.HostWhitelist(mgmtLetsencryptDomain),
|
||||
}
|
||||
tls := &tls.Config{GetCertificate: certManager.GetCertificate}
|
||||
|
||||
credentials := credentials.NewTLS(tls)
|
||||
opts = append(opts, grpc.Creds(credentials))
|
||||
|
||||
// listener to handle Let's encrypt certificate challenge
|
||||
go func() {
|
||||
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
||||
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
||||
}
|
||||
}()
|
||||
transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(mgmtDataDir, mgmtLetsencryptDomain))
|
||||
opts = append(opts, grpc.Creds(transportCredentials))
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
33
cmd/root.go
33
cmd/root.go
@ -1,16 +1,11 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
@ -79,31 +74,3 @@ func InitLog(logLevel string) {
|
||||
}
|
||||
log.SetLevel(level)
|
||||
}
|
||||
|
||||
func enableLetsEncrypt(datadir string, letsencryptDomain string) credentials.TransportCredentials {
|
||||
certDir := filepath.Join(datadir, "letsencrypt")
|
||||
|
||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(certDir, os.ModeDir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
|
||||
|
||||
certManager := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
|
||||
}
|
||||
|
||||
// listener to handle Let's encrypt certificate challenge
|
||||
go func() {
|
||||
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
||||
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return credentials.NewTLS(&tls.Config{GetCertificate: certManager.GetCertificate})
|
||||
}
|
||||
|
@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
sig "github.com/wiretrustee/wiretrustee/signal"
|
||||
sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"net"
|
||||
"os"
|
||||
@ -45,8 +47,8 @@ var (
|
||||
}
|
||||
|
||||
var opts []grpc.ServerOption
|
||||
if mgmtLetsencryptDomain != "" {
|
||||
transportCredentials := enableLetsEncrypt(signalDataDir, signalLetsencryptDomain)
|
||||
if signalLetsencryptDomain != "" {
|
||||
transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(signalDataDir, signalLetsencryptDomain))
|
||||
opts = append(opts, grpc.Creds(transportCredentials))
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
package signal
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@ -7,9 +7,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// As set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service.
|
||||
// We want to make sure that the Connection Candidates and other irrelevant (to the Signal Exchange)
|
||||
// information can't be read anywhere else but the Peer the message is being sent to.
|
||||
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||
// Wireguard keys are used for encryption
|
||||
|
13
encryption/encryption_suite_test.go
Normal file
13
encryption/encryption_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package encryption_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManagement(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Management Service Suite")
|
||||
}
|
60
encryption/encryption_test.go
Normal file
60
encryption/encryption_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
package encryption_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/encryption/testprotos"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const ()
|
||||
|
||||
var _ = Describe("Encryption", func() {
|
||||
|
||||
var (
|
||||
encryptionKey wgtypes.Key
|
||||
decryptionKey wgtypes.Key
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
encryptionKey, err = wgtypes.GenerateKey()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
decryptionKey, err = wgtypes.GenerateKey()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("decrypting a plain message", func() {
|
||||
Context("when it was encrypted with Wireguard keys", func() {
|
||||
Specify("should be successful", func() {
|
||||
msg := "message"
|
||||
encryptedMsg, err := encryption.Encrypt([]byte(msg), decryptionKey.PublicKey(), encryptionKey)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
decryptedMsg, err := encryption.Decrypt(encryptedMsg, encryptionKey.PublicKey(), decryptionKey)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(string(decryptedMsg)).To(BeEquivalentTo(msg))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("decrypting a protobuf message", func() {
|
||||
Context("when it was encrypted with Wireguard keys", func() {
|
||||
Specify("should be successful", func() {
|
||||
|
||||
protoMsg := &testprotos.TestMessage{Body: "message"}
|
||||
encryptedMsg, err := encryption.EncryptMessage(decryptionKey.PublicKey(), encryptionKey, protoMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
decryptedMsg := &testprotos.TestMessage{}
|
||||
err = encryption.DecryptMessage(encryptionKey.PublicKey(), decryptionKey, encryptedMsg, decryptedMsg)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(decryptedMsg.GetBody()).To(BeEquivalentTo(protoMsg.GetBody()))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
})
|
40
encryption/letsencrypt.go
Normal file
40
encryption/letsencrypt.go
Normal file
@ -0,0 +1,40 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// EnableLetsEncrypt wraps common logic of generating Let's encrypt certificate.
|
||||
// Includes a HTTP handler and listener to solve the Let's encrypt challenge
|
||||
func EnableLetsEncrypt(datadir string, letsencryptDomain string) *tls.Config {
|
||||
certDir := filepath.Join(datadir, "letsencrypt")
|
||||
|
||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(certDir, os.ModeDir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
|
||||
|
||||
certManager := autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
|
||||
}
|
||||
|
||||
// listener to handle Let's encrypt certificate challenge
|
||||
go func() {
|
||||
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
||||
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &tls.Config{GetCertificate: certManager.GetCertificate}
|
||||
}
|
40
encryption/message.go
Normal file
40
encryption/message.go
Normal file
@ -0,0 +1,40 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// EncryptMessage encrypts a body of the given protobuf Message
|
||||
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
|
||||
byteResp, err := pb.Marshal(message)
|
||||
if err != nil {
|
||||
log.Errorf("failed marshalling message %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBytes, err := Encrypt(byteResp, remotePubKey, ourPrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed encrypting SyncResponse %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return encryptedBytes, nil
|
||||
}
|
||||
|
||||
// DecryptMessage decrypts an encrypted message into given protobuf Message
|
||||
func DecryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, encryptedMessage []byte, message pb.Message) error {
|
||||
decrypted, err := Decrypt(encryptedMessage, remotePubKey, ourPrivateKey)
|
||||
if err != nil {
|
||||
log.Warnf("error while decrypting Sync request message from peer %s", remotePubKey.String())
|
||||
return err
|
||||
}
|
||||
|
||||
err = pb.Unmarshal(decrypted, message)
|
||||
if err != nil {
|
||||
log.Warnf("error while umarshalling Sync request message from peer %s", remotePubKey.String())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
2
encryption/testprotos/generate.sh
Executable file
2
encryption/testprotos/generate.sh
Executable file
@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
protoc -I testprotos/ testprotos/testproto.proto --go_out=.
|
142
encryption/testprotos/testproto.pb.go
Normal file
142
encryption/testprotos/testproto.pb.go
Normal file
@ -0,0 +1,142 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.12.4
|
||||
// source: testproto.proto
|
||||
|
||||
package testprotos
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type TestMessage struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
|
||||
}
|
||||
|
||||
func (x *TestMessage) Reset() {
|
||||
*x = TestMessage{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *TestMessage) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*TestMessage) ProtoMessage() {}
|
||||
|
||||
func (x *TestMessage) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use TestMessage.ProtoReflect.Descriptor instead.
|
||||
func (*TestMessage) Descriptor() ([]byte, []int) {
|
||||
return file_testproto_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *TestMessage) GetBody() string {
|
||||
if x != nil {
|
||||
return x.Body
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_testproto_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_testproto_proto_rawDesc = []byte{
|
||||
0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a,
|
||||
0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
|
||||
0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_testproto_proto_rawDescOnce sync.Once
|
||||
file_testproto_proto_rawDescData = file_testproto_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_testproto_proto_rawDescGZIP() []byte {
|
||||
file_testproto_proto_rawDescOnce.Do(func() {
|
||||
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData)
|
||||
})
|
||||
return file_testproto_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_testproto_proto_goTypes = []interface{}{
|
||||
(*TestMessage)(nil), // 0: testprotos.TestMessage
|
||||
}
|
||||
var file_testproto_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_testproto_proto_init() }
|
||||
func file_testproto_proto_init() {
|
||||
if File_testproto_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*TestMessage); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_testproto_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_testproto_proto_goTypes,
|
||||
DependencyIndexes: file_testproto_proto_depIdxs,
|
||||
MessageInfos: file_testproto_proto_msgTypes,
|
||||
}.Build()
|
||||
File_testproto_proto = out.File
|
||||
file_testproto_proto_rawDesc = nil
|
||||
file_testproto_proto_goTypes = nil
|
||||
file_testproto_proto_depIdxs = nil
|
||||
}
|
9
encryption/testprotos/testproto.proto
Normal file
9
encryption/testprotos/testproto.proto
Normal file
@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "/testprotos";
|
||||
|
||||
package testprotos;
|
||||
|
||||
message TestMessage {
|
||||
string body = 1;
|
||||
}
|
@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/signal"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
@ -94,7 +94,7 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key)
|
||||
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
||||
@ -106,7 +106,7 @@ var _ = Describe("Management service", func() {
|
||||
encryptedResponse := &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(encryptedResponse)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||
decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
resp := &mgmtProto.SyncResponse{}
|
||||
@ -127,7 +127,7 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key)
|
||||
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
||||
@ -140,7 +140,7 @@ var _ = Describe("Management service", func() {
|
||||
encryptedResponse := &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(encryptedResponse)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||
decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
resp := &mgmtProto.SyncResponse{}
|
||||
err = pb.Unmarshal(decryptedBytes, resp)
|
||||
@ -153,7 +153,7 @@ var _ = Describe("Management service", func() {
|
||||
go func() {
|
||||
err = sync.RecvMsg(encryptedResponse)
|
||||
|
||||
decryptedBytes, err = signal.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||
decryptedBytes, err = encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
resp = &mgmtProto.SyncResponse{}
|
||||
err = pb.Unmarshal(decryptedBytes, resp)
|
||||
@ -240,7 +240,7 @@ var _ = Describe("Management service", func() {
|
||||
for _, peer := range peers {
|
||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, peer)
|
||||
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// receive stream
|
||||
@ -261,7 +261,7 @@ var _ = Describe("Management service", func() {
|
||||
} else if err != nil {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, peer)
|
||||
decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
resp := &mgmtProto.SyncResponse{}
|
||||
|
@ -1,44 +0,0 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"github.com/wiretrustee/wiretrustee/signal"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// EncryptMessage encrypts a body of the given pn.Message and wraps into proto.EncryptedMessage
|
||||
func EncryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, message pb.Message) (*proto.EncryptedMessage, error) {
|
||||
byteResp, err := pb.Marshal(message)
|
||||
if err != nil {
|
||||
log.Errorf("failed marshalling message %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBytes, err := signal.Encrypt(byteResp, peerKey, serverPrivateKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed encrypting SyncResponse %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: serverPrivateKey.PublicKey().String(),
|
||||
Body: encryptedBytes}, nil
|
||||
}
|
||||
|
||||
//DecryptMessage decrypts an encrypted message (proto.EncryptedMessage)
|
||||
func DecryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, encryptedMessage *proto.EncryptedMessage, message pb.Message) error {
|
||||
decrypted, err := signal.Decrypt(encryptedMessage.Body, peerKey, serverPrivateKey)
|
||||
if err != nil {
|
||||
log.Warnf("error while decrypting Sync request message from peer %s", peerKey.String())
|
||||
return err
|
||||
}
|
||||
|
||||
err = pb.Unmarshal(decrypted, message)
|
||||
if err != nil {
|
||||
log.Warnf("error while umarshalling Sync request message from peer %s", peerKey.String())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -76,7 +77,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
}
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
err = DecryptMessage(peerKey, s.wgKey, req, syncReq)
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||
}
|
||||
@ -99,12 +100,15 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
}
|
||||
log.Debugf("recevied an update for peer %s", peerKey.String())
|
||||
|
||||
encryptedResp, err := EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
|
||||
err = srv.SendMsg(encryptedResp)
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
@ -200,12 +204,15 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, srv proto.ManagementServic
|
||||
Peers: peers,
|
||||
}
|
||||
|
||||
encryptedResp, err := EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
err = srv.Send(encryptedResp)
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed sending SyncResponse %v", err)
|
||||
|
@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
@ -162,12 +162,9 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decryptedBody, err := Decrypt(msg.GetBody(), remoteKey, c.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := &proto.Body{}
|
||||
err = pb.Unmarshal(decryptedBody, body)
|
||||
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -181,16 +178,13 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er
|
||||
|
||||
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
||||
body, err := pb.Marshal(msg.GetBody())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBody, err := Encrypt(body, remoteKey, c.key)
|
||||
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package signal
|
||||
|
||||
import (
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"testing"
|
||||
)
|
||||
@ -21,13 +22,13 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
encryptedMessage, err := Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey)
|
||||
encryptedMessage, err := encryption.Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
decryptedMessage, err := Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey)
|
||||
decryptedMessage, err := encryption.Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user