mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-24 00:54:01 +01:00
Extract common server encryption logic (#65)
* refactor: extract common message encryption logic * refactor: move letsencrypt logic to common * refactor: rename common package to encryption * test: add encryption tests
This commit is contained in:
parent
c98be683bf
commit
2172d6f1b9
@ -1,21 +1,18 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
mgmt "github.com/wiretrustee/wiretrustee/management"
|
mgmt "github.com/wiretrustee/wiretrustee/management"
|
||||||
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
|
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -52,34 +49,8 @@ var (
|
|||||||
var opts []grpc.ServerOption
|
var opts []grpc.ServerOption
|
||||||
|
|
||||||
if mgmtLetsencryptDomain != "" {
|
if mgmtLetsencryptDomain != "" {
|
||||||
|
transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(mgmtDataDir, mgmtLetsencryptDomain))
|
||||||
certDir := filepath.Join(mgmtDataDir, "letsencrypt")
|
opts = append(opts, grpc.Creds(transportCredentials))
|
||||||
|
|
||||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
|
||||||
err = os.MkdirAll(certDir, os.ModeDir)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", mgmtLetsencryptDomain, certDir)
|
|
||||||
|
|
||||||
certManager := autocert.Manager{
|
|
||||||
Prompt: autocert.AcceptTOS,
|
|
||||||
Cache: autocert.DirCache(certDir),
|
|
||||||
HostPolicy: autocert.HostWhitelist(mgmtLetsencryptDomain),
|
|
||||||
}
|
|
||||||
tls := &tls.Config{GetCertificate: certManager.GetCertificate}
|
|
||||||
|
|
||||||
credentials := credentials.NewTLS(tls)
|
|
||||||
opts = append(opts, grpc.Creds(credentials))
|
|
||||||
|
|
||||||
// listener to handle Let's encrypt certificate challenge
|
|
||||||
go func() {
|
|
||||||
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
|
||||||
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||||
|
33
cmd/root.go
33
cmd/root.go
@ -1,16 +1,11 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -79,31 +74,3 @@ func InitLog(logLevel string) {
|
|||||||
}
|
}
|
||||||
log.SetLevel(level)
|
log.SetLevel(level)
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableLetsEncrypt(datadir string, letsencryptDomain string) credentials.TransportCredentials {
|
|
||||||
certDir := filepath.Join(datadir, "letsencrypt")
|
|
||||||
|
|
||||||
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
|
||||||
err = os.MkdirAll(certDir, os.ModeDir)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
|
|
||||||
|
|
||||||
certManager := autocert.Manager{
|
|
||||||
Prompt: autocert.AcceptTOS,
|
|
||||||
Cache: autocert.DirCache(certDir),
|
|
||||||
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
|
|
||||||
}
|
|
||||||
|
|
||||||
// listener to handle Let's encrypt certificate challenge
|
|
||||||
go func() {
|
|
||||||
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
|
||||||
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return credentials.NewTLS(&tls.Config{GetCertificate: certManager.GetCertificate})
|
|
||||||
}
|
|
||||||
|
@ -5,9 +5,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
sig "github.com/wiretrustee/wiretrustee/signal"
|
sig "github.com/wiretrustee/wiretrustee/signal"
|
||||||
sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -45,8 +47,8 @@ var (
|
|||||||
}
|
}
|
||||||
|
|
||||||
var opts []grpc.ServerOption
|
var opts []grpc.ServerOption
|
||||||
if mgmtLetsencryptDomain != "" {
|
if signalLetsencryptDomain != "" {
|
||||||
transportCredentials := enableLetsEncrypt(signalDataDir, signalLetsencryptDomain)
|
transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(signalDataDir, signalLetsencryptDomain))
|
||||||
opts = append(opts, grpc.Creds(transportCredentials))
|
opts = append(opts, grpc.Creds(transportCredentials))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package signal
|
package encryption
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
@ -7,9 +7,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
// As set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service.
|
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
||||||
// We want to make sure that the Connection Candidates and other irrelevant (to the Signal Exchange)
|
|
||||||
// information can't be read anywhere else but the Peer the message is being sent to.
|
|
||||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||||
// Wireguard keys are used for encryption
|
// Wireguard keys are used for encryption
|
||||||
|
|
13
encryption/encryption_suite_test.go
Normal file
13
encryption/encryption_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package encryption_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManagement(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Management Service Suite")
|
||||||
|
}
|
60
encryption/encryption_test.go
Normal file
60
encryption/encryption_test.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package encryption_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption/testprotos"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ()
|
||||||
|
|
||||||
|
var _ = Describe("Encryption", func() {
|
||||||
|
|
||||||
|
var (
|
||||||
|
encryptionKey wgtypes.Key
|
||||||
|
decryptionKey wgtypes.Key
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
encryptionKey, err = wgtypes.GenerateKey()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
decryptionKey, err = wgtypes.GenerateKey()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("decrypting a plain message", func() {
|
||||||
|
Context("when it was encrypted with Wireguard keys", func() {
|
||||||
|
Specify("should be successful", func() {
|
||||||
|
msg := "message"
|
||||||
|
encryptedMsg, err := encryption.Encrypt([]byte(msg), decryptionKey.PublicKey(), encryptionKey)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
decryptedMsg, err := encryption.Decrypt(encryptedMsg, encryptionKey.PublicKey(), decryptionKey)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(string(decryptedMsg)).To(BeEquivalentTo(msg))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("decrypting a protobuf message", func() {
|
||||||
|
Context("when it was encrypted with Wireguard keys", func() {
|
||||||
|
Specify("should be successful", func() {
|
||||||
|
|
||||||
|
protoMsg := &testprotos.TestMessage{Body: "message"}
|
||||||
|
encryptedMsg, err := encryption.EncryptMessage(decryptionKey.PublicKey(), encryptionKey, protoMsg)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
decryptedMsg := &testprotos.TestMessage{}
|
||||||
|
err = encryption.DecryptMessage(encryptionKey.PublicKey(), decryptionKey, encryptedMsg, decryptedMsg)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
|
Expect(decryptedMsg.GetBody()).To(BeEquivalentTo(protoMsg.GetBody()))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
40
encryption/letsencrypt.go
Normal file
40
encryption/letsencrypt.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnableLetsEncrypt wraps common logic of generating Let's encrypt certificate.
|
||||||
|
// Includes a HTTP handler and listener to solve the Let's encrypt challenge
|
||||||
|
func EnableLetsEncrypt(datadir string, letsencryptDomain string) *tls.Config {
|
||||||
|
certDir := filepath.Join(datadir, "letsencrypt")
|
||||||
|
|
||||||
|
if _, err := os.Stat(certDir); os.IsNotExist(err) {
|
||||||
|
err = os.MkdirAll(certDir, os.ModeDir)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
|
||||||
|
|
||||||
|
certManager := autocert.Manager{
|
||||||
|
Prompt: autocert.AcceptTOS,
|
||||||
|
Cache: autocert.DirCache(certDir),
|
||||||
|
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
|
||||||
|
}
|
||||||
|
|
||||||
|
// listener to handle Let's encrypt certificate challenge
|
||||||
|
go func() {
|
||||||
|
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
|
||||||
|
log.Fatalf("failed to serve letsencrypt handler: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &tls.Config{GetCertificate: certManager.GetCertificate}
|
||||||
|
}
|
40
encryption/message.go
Normal file
40
encryption/message.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
pb "github.com/golang/protobuf/proto" //nolint
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncryptMessage encrypts a body of the given protobuf Message
|
||||||
|
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
|
||||||
|
byteResp, err := pb.Marshal(message)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed marshalling message %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedBytes, err := Encrypt(byteResp, remotePubKey, ourPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed encrypting SyncResponse %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return encryptedBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptMessage decrypts an encrypted message into given protobuf Message
|
||||||
|
func DecryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, encryptedMessage []byte, message pb.Message) error {
|
||||||
|
decrypted, err := Decrypt(encryptedMessage, remotePubKey, ourPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error while decrypting Sync request message from peer %s", remotePubKey.String())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pb.Unmarshal(decrypted, message)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error while umarshalling Sync request message from peer %s", remotePubKey.String())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
2
encryption/testprotos/generate.sh
Executable file
2
encryption/testprotos/generate.sh
Executable file
@ -0,0 +1,2 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
protoc -I testprotos/ testprotos/testproto.proto --go_out=.
|
142
encryption/testprotos/testproto.pb.go
Normal file
142
encryption/testprotos/testproto.pb.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.26.0
|
||||||
|
// protoc v3.12.4
|
||||||
|
// source: testproto.proto
|
||||||
|
|
||||||
|
package testprotos
|
||||||
|
|
||||||
|
import (
|
||||||
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
type TestMessage struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TestMessage) Reset() {
|
||||||
|
*x = TestMessage{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_testproto_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TestMessage) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*TestMessage) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *TestMessage) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_testproto_proto_msgTypes[0]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use TestMessage.ProtoReflect.Descriptor instead.
|
||||||
|
func (*TestMessage) Descriptor() ([]byte, []int) {
|
||||||
|
return file_testproto_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *TestMessage) GetBody() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Body
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var File_testproto_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
|
var file_testproto_proto_rawDesc = []byte{
|
||||||
|
0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||||
|
0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a,
|
||||||
|
0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04,
|
||||||
|
0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
|
||||||
|
0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62,
|
||||||
|
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
file_testproto_proto_rawDescOnce sync.Once
|
||||||
|
file_testproto_proto_rawDescData = file_testproto_proto_rawDesc
|
||||||
|
)
|
||||||
|
|
||||||
|
func file_testproto_proto_rawDescGZIP() []byte {
|
||||||
|
file_testproto_proto_rawDescOnce.Do(func() {
|
||||||
|
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData)
|
||||||
|
})
|
||||||
|
return file_testproto_proto_rawDescData
|
||||||
|
}
|
||||||
|
|
||||||
|
var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||||
|
var file_testproto_proto_goTypes = []interface{}{
|
||||||
|
(*TestMessage)(nil), // 0: testprotos.TestMessage
|
||||||
|
}
|
||||||
|
var file_testproto_proto_depIdxs = []int32{
|
||||||
|
0, // [0:0] is the sub-list for method output_type
|
||||||
|
0, // [0:0] is the sub-list for method input_type
|
||||||
|
0, // [0:0] is the sub-list for extension type_name
|
||||||
|
0, // [0:0] is the sub-list for extension extendee
|
||||||
|
0, // [0:0] is the sub-list for field type_name
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { file_testproto_proto_init() }
|
||||||
|
func file_testproto_proto_init() {
|
||||||
|
if File_testproto_proto != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !protoimpl.UnsafeEnabled {
|
||||||
|
file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*TestMessage); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: file_testproto_proto_rawDesc,
|
||||||
|
NumEnums: 0,
|
||||||
|
NumMessages: 1,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 0,
|
||||||
|
},
|
||||||
|
GoTypes: file_testproto_proto_goTypes,
|
||||||
|
DependencyIndexes: file_testproto_proto_depIdxs,
|
||||||
|
MessageInfos: file_testproto_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_testproto_proto = out.File
|
||||||
|
file_testproto_proto_rawDesc = nil
|
||||||
|
file_testproto_proto_goTypes = nil
|
||||||
|
file_testproto_proto_depIdxs = nil
|
||||||
|
}
|
9
encryption/testprotos/testproto.proto
Normal file
9
encryption/testprotos/testproto.proto
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
option go_package = "/testprotos";
|
||||||
|
|
||||||
|
package testprotos;
|
||||||
|
|
||||||
|
message TestMessage {
|
||||||
|
string body = 1;
|
||||||
|
}
|
@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
pb "github.com/golang/protobuf/proto" //nolint
|
pb "github.com/golang/protobuf/proto" //nolint
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/wiretrustee/wiretrustee/signal"
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -94,7 +94,7 @@ var _ = Describe("Management service", func() {
|
|||||||
|
|
||||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key)
|
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
||||||
@ -106,7 +106,7 @@ var _ = Describe("Management service", func() {
|
|||||||
encryptedResponse := &mgmtProto.EncryptedMessage{}
|
encryptedResponse := &mgmtProto.EncryptedMessage{}
|
||||||
err = sync.RecvMsg(encryptedResponse)
|
err = sync.RecvMsg(encryptedResponse)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
resp := &mgmtProto.SyncResponse{}
|
resp := &mgmtProto.SyncResponse{}
|
||||||
@ -127,7 +127,7 @@ var _ = Describe("Management service", func() {
|
|||||||
|
|
||||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key)
|
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
|
||||||
@ -140,7 +140,7 @@ var _ = Describe("Management service", func() {
|
|||||||
encryptedResponse := &mgmtProto.EncryptedMessage{}
|
encryptedResponse := &mgmtProto.EncryptedMessage{}
|
||||||
err = sync.RecvMsg(encryptedResponse)
|
err = sync.RecvMsg(encryptedResponse)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
resp := &mgmtProto.SyncResponse{}
|
resp := &mgmtProto.SyncResponse{}
|
||||||
err = pb.Unmarshal(decryptedBytes, resp)
|
err = pb.Unmarshal(decryptedBytes, resp)
|
||||||
@ -153,7 +153,7 @@ var _ = Describe("Management service", func() {
|
|||||||
go func() {
|
go func() {
|
||||||
err = sync.RecvMsg(encryptedResponse)
|
err = sync.RecvMsg(encryptedResponse)
|
||||||
|
|
||||||
decryptedBytes, err = signal.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
decryptedBytes, err = encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
resp = &mgmtProto.SyncResponse{}
|
resp = &mgmtProto.SyncResponse{}
|
||||||
err = pb.Unmarshal(decryptedBytes, resp)
|
err = pb.Unmarshal(decryptedBytes, resp)
|
||||||
@ -240,7 +240,7 @@ var _ = Describe("Management service", func() {
|
|||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, peer)
|
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
// receive stream
|
// receive stream
|
||||||
@ -261,7 +261,7 @@ var _ = Describe("Management service", func() {
|
|||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
}
|
}
|
||||||
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, peer)
|
decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
resp := &mgmtProto.SyncResponse{}
|
resp := &mgmtProto.SyncResponse{}
|
||||||
|
@ -1,44 +0,0 @@
|
|||||||
package management
|
|
||||||
|
|
||||||
import (
|
|
||||||
pb "github.com/golang/protobuf/proto" //nolint
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
|
||||||
"github.com/wiretrustee/wiretrustee/signal"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EncryptMessage encrypts a body of the given pn.Message and wraps into proto.EncryptedMessage
|
|
||||||
func EncryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, message pb.Message) (*proto.EncryptedMessage, error) {
|
|
||||||
byteResp, err := pb.Marshal(message)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed marshalling message %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
encryptedBytes, err := signal.Encrypt(byteResp, peerKey, serverPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed encrypting SyncResponse %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.EncryptedMessage{
|
|
||||||
WgPubKey: serverPrivateKey.PublicKey().String(),
|
|
||||||
Body: encryptedBytes}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//DecryptMessage decrypts an encrypted message (proto.EncryptedMessage)
|
|
||||||
func DecryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, encryptedMessage *proto.EncryptedMessage, message pb.Message) error {
|
|
||||||
decrypted, err := signal.Decrypt(encryptedMessage.Body, peerKey, serverPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error while decrypting Sync request message from peer %s", peerKey.String())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pb.Unmarshal(decrypted, message)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error while umarshalling Sync request message from peer %s", peerKey.String())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"github.com/golang/protobuf/ptypes/timestamp"
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@ -76,7 +77,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
}
|
}
|
||||||
|
|
||||||
syncReq := &proto.SyncRequest{}
|
syncReq := &proto.SyncRequest{}
|
||||||
err = DecryptMessage(peerKey, s.wgKey, req, syncReq)
|
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid request message")
|
return status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||||
}
|
}
|
||||||
@ -99,12 +100,15 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
}
|
}
|
||||||
log.Debugf("recevied an update for peer %s", peerKey.String())
|
log.Debugf("recevied an update for peer %s", peerKey.String())
|
||||||
|
|
||||||
encryptedResp, err := EncryptMessage(peerKey, s.wgKey, update.Update)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed processing update message")
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = srv.SendMsg(encryptedResp)
|
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed sending update message")
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
}
|
}
|
||||||
@ -200,12 +204,15 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, srv proto.ManagementServic
|
|||||||
Peers: peers,
|
Peers: peers,
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedResp, err := EncryptMessage(peerKey, s.wgKey, plainResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "error handling request")
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = srv.Send(encryptedResp)
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed sending SyncResponse %v", err)
|
log.Errorf("failed sending SyncResponse %v", err)
|
||||||
|
@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
pb "github.com/golang/protobuf/proto" //nolint
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@ -162,12 +162,9 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
decryptedBody, err := Decrypt(msg.GetBody(), remoteKey, c.key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
body := &proto.Body{}
|
body := &proto.Body{}
|
||||||
err = pb.Unmarshal(decryptedBody, body)
|
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -181,16 +178,13 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er
|
|||||||
|
|
||||||
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||||
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
||||||
body, err := pb.Marshal(msg.GetBody())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedBody, err := Encrypt(body, remoteKey, c.key)
|
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package signal
|
package signal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -21,13 +22,13 @@ func TestEncryptDecrypt(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedMessage, err := Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey)
|
encryptedMessage, err := encryption.Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
decryptedMessage, err := Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey)
|
decryptedMessage, err := encryption.Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user