Extract common server encryption logic (#65)

* refactor: extract common message encryption logic
* refactor: move letsencrypt logic to common
* refactor: rename common package to encryption
* test: add encryption tests
This commit is contained in:
Mikhail Bragin 2021-07-22 15:23:24 +02:00 committed by GitHub
parent c98be683bf
commit 2172d6f1b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 343 additions and 141 deletions

View File

@ -1,21 +1,18 @@
package cmd package cmd
import ( import (
"crypto/tls"
"flag" "flag"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wiretrustee/wiretrustee/encryption"
mgmt "github.com/wiretrustee/wiretrustee/management" mgmt "github.com/wiretrustee/wiretrustee/management"
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto" mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
"golang.org/x/crypto/acme/autocert"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"net" "net"
"net/http"
"os" "os"
"path/filepath"
"time" "time"
) )
@ -52,34 +49,8 @@ var (
var opts []grpc.ServerOption var opts []grpc.ServerOption
if mgmtLetsencryptDomain != "" { if mgmtLetsencryptDomain != "" {
transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(mgmtDataDir, mgmtLetsencryptDomain))
certDir := filepath.Join(mgmtDataDir, "letsencrypt") opts = append(opts, grpc.Creds(transportCredentials))
if _, err := os.Stat(certDir); os.IsNotExist(err) {
err = os.MkdirAll(certDir, os.ModeDir)
if err != nil {
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
}
}
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", mgmtLetsencryptDomain, certDir)
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
HostPolicy: autocert.HostWhitelist(mgmtLetsencryptDomain),
}
tls := &tls.Config{GetCertificate: certManager.GetCertificate}
credentials := credentials.NewTLS(tls)
opts = append(opts, grpc.Creds(credentials))
// listener to handle Let's encrypt certificate challenge
go func() {
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
log.Fatalf("failed to serve letsencrypt handler: %v", err)
}
}()
} }
opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) opts = append(opts, grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))

View File

@ -1,16 +1,11 @@
package cmd package cmd
import ( import (
"crypto/tls"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert"
"google.golang.org/grpc/credentials"
"net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"runtime" "runtime"
) )
@ -79,31 +74,3 @@ func InitLog(logLevel string) {
} }
log.SetLevel(level) log.SetLevel(level)
} }
func enableLetsEncrypt(datadir string, letsencryptDomain string) credentials.TransportCredentials {
certDir := filepath.Join(datadir, "letsencrypt")
if _, err := os.Stat(certDir); os.IsNotExist(err) {
err = os.MkdirAll(certDir, os.ModeDir)
if err != nil {
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
}
}
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
}
// listener to handle Let's encrypt certificate challenge
go func() {
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
log.Fatalf("failed to serve letsencrypt handler: %v", err)
}
}()
return credentials.NewTLS(&tls.Config{GetCertificate: certManager.GetCertificate})
}

View File

@ -5,9 +5,11 @@ import (
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wiretrustee/wiretrustee/encryption"
sig "github.com/wiretrustee/wiretrustee/signal" sig "github.com/wiretrustee/wiretrustee/signal"
sigProto "github.com/wiretrustee/wiretrustee/signal/proto" sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"net" "net"
"os" "os"
@ -45,8 +47,8 @@ var (
} }
var opts []grpc.ServerOption var opts []grpc.ServerOption
if mgmtLetsencryptDomain != "" { if signalLetsencryptDomain != "" {
transportCredentials := enableLetsEncrypt(signalDataDir, signalLetsencryptDomain) transportCredentials := credentials.NewTLS(encryption.EnableLetsEncrypt(signalDataDir, signalLetsencryptDomain))
opts = append(opts, grpc.Creds(transportCredentials)) opts = append(opts, grpc.Creds(transportCredentials))
} }

View File

@ -1,4 +1,4 @@
package signal package encryption
import ( import (
"crypto/rand" "crypto/rand"
@ -7,9 +7,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// As set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service. // A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
// We want to make sure that the Connection Candidates and other irrelevant (to the Signal Exchange)
// information can't be read anywhere else but the Peer the message is being sent to.
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate) // These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
// Wireguard keys are used for encryption // Wireguard keys are used for encryption

View File

@ -0,0 +1,13 @@
package encryption_test
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestManagement(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Management Service Suite")
}

View File

@ -0,0 +1,60 @@
package encryption_test
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/encryption/testprotos"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const ()
var _ = Describe("Encryption", func() {
var (
encryptionKey wgtypes.Key
decryptionKey wgtypes.Key
)
BeforeEach(func() {
var err error
encryptionKey, err = wgtypes.GenerateKey()
Expect(err).NotTo(HaveOccurred())
decryptionKey, err = wgtypes.GenerateKey()
Expect(err).NotTo(HaveOccurred())
})
Context("decrypting a plain message", func() {
Context("when it was encrypted with Wireguard keys", func() {
Specify("should be successful", func() {
msg := "message"
encryptedMsg, err := encryption.Encrypt([]byte(msg), decryptionKey.PublicKey(), encryptionKey)
Expect(err).NotTo(HaveOccurred())
decryptedMsg, err := encryption.Decrypt(encryptedMsg, encryptionKey.PublicKey(), decryptionKey)
Expect(err).NotTo(HaveOccurred())
Expect(string(decryptedMsg)).To(BeEquivalentTo(msg))
})
})
})
Context("decrypting a protobuf message", func() {
Context("when it was encrypted with Wireguard keys", func() {
Specify("should be successful", func() {
protoMsg := &testprotos.TestMessage{Body: "message"}
encryptedMsg, err := encryption.EncryptMessage(decryptionKey.PublicKey(), encryptionKey, protoMsg)
Expect(err).NotTo(HaveOccurred())
decryptedMsg := &testprotos.TestMessage{}
err = encryption.DecryptMessage(encryptionKey.PublicKey(), decryptionKey, encryptedMsg, decryptedMsg)
Expect(err).NotTo(HaveOccurred())
Expect(decryptedMsg.GetBody()).To(BeEquivalentTo(protoMsg.GetBody()))
})
})
})
})

40
encryption/letsencrypt.go Normal file
View File

@ -0,0 +1,40 @@
package encryption
import (
"crypto/tls"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
"net/http"
"os"
"path/filepath"
)
// EnableLetsEncrypt wraps common logic of generating Let's encrypt certificate.
// Includes a HTTP handler and listener to solve the Let's encrypt challenge
func EnableLetsEncrypt(datadir string, letsencryptDomain string) *tls.Config {
certDir := filepath.Join(datadir, "letsencrypt")
if _, err := os.Stat(certDir); os.IsNotExist(err) {
err = os.MkdirAll(certDir, os.ModeDir)
if err != nil {
log.Fatalf("failed creating Let's encrypt certdir: %s: %v", certDir, err)
}
}
log.Infof("running with Let's encrypt with domain %s. Cert will be stored in %s", letsencryptDomain, certDir)
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
}
// listener to handle Let's encrypt certificate challenge
go func() {
if err := http.Serve(certManager.Listener(), certManager.HTTPHandler(nil)); err != nil {
log.Fatalf("failed to serve letsencrypt handler: %v", err)
}
}()
return &tls.Config{GetCertificate: certManager.GetCertificate}
}

40
encryption/message.go Normal file
View File

@ -0,0 +1,40 @@
package encryption
import (
pb "github.com/golang/protobuf/proto" //nolint
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// EncryptMessage encrypts a body of the given protobuf Message
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
byteResp, err := pb.Marshal(message)
if err != nil {
log.Errorf("failed marshalling message %v", err)
return nil, err
}
encryptedBytes, err := Encrypt(byteResp, remotePubKey, ourPrivateKey)
if err != nil {
log.Errorf("failed encrypting SyncResponse %v", err)
return nil, err
}
return encryptedBytes, nil
}
// DecryptMessage decrypts an encrypted message into given protobuf Message
func DecryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, encryptedMessage []byte, message pb.Message) error {
decrypted, err := Decrypt(encryptedMessage, remotePubKey, ourPrivateKey)
if err != nil {
log.Warnf("error while decrypting Sync request message from peer %s", remotePubKey.String())
return err
}
err = pb.Unmarshal(decrypted, message)
if err != nil {
log.Warnf("error while umarshalling Sync request message from peer %s", remotePubKey.String())
return err
}
return nil
}

View File

@ -0,0 +1,2 @@
#!/bin/bash
protoc -I testprotos/ testprotos/testproto.proto --go_out=.

View File

@ -0,0 +1,142 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.12.4
// source: testproto.proto
package testprotos
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type TestMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
}
func (x *TestMessage) Reset() {
*x = TestMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_testproto_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TestMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TestMessage) ProtoMessage() {}
func (x *TestMessage) ProtoReflect() protoreflect.Message {
mi := &file_testproto_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TestMessage.ProtoReflect.Descriptor instead.
func (*TestMessage) Descriptor() ([]byte, []int) {
return file_testproto_proto_rawDescGZIP(), []int{0}
}
func (x *TestMessage) GetBody() string {
if x != nil {
return x.Body
}
return ""
}
var File_testproto_proto protoreflect.FileDescriptor
var file_testproto_proto_rawDesc = []byte{
0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a,
0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04,
0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_testproto_proto_rawDescOnce sync.Once
file_testproto_proto_rawDescData = file_testproto_proto_rawDesc
)
func file_testproto_proto_rawDescGZIP() []byte {
file_testproto_proto_rawDescOnce.Do(func() {
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData)
})
return file_testproto_proto_rawDescData
}
var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_testproto_proto_goTypes = []interface{}{
(*TestMessage)(nil), // 0: testprotos.TestMessage
}
var file_testproto_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_testproto_proto_init() }
func file_testproto_proto_init() {
if File_testproto_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TestMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_testproto_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_testproto_proto_goTypes,
DependencyIndexes: file_testproto_proto_depIdxs,
MessageInfos: file_testproto_proto_msgTypes,
}.Build()
File_testproto_proto = out.File
file_testproto_proto_rawDesc = nil
file_testproto_proto_goTypes = nil
file_testproto_proto_depIdxs = nil
}

View File

@ -0,0 +1,9 @@
syntax = "proto3";
option go_package = "/testprotos";
package testprotos;
message TestMessage {
string body = 1;
}

View File

@ -4,7 +4,7 @@ import (
"context" "context"
pb "github.com/golang/protobuf/proto" //nolint pb "github.com/golang/protobuf/proto" //nolint
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/signal" "github.com/wiretrustee/wiretrustee/encryption"
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
@ -94,7 +94,7 @@ var _ = Describe("Management service", func() {
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
@ -106,7 +106,7 @@ var _ = Describe("Management service", func() {
encryptedResponse := &mgmtProto.EncryptedMessage{} encryptedResponse := &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(encryptedResponse) err = sync.RecvMsg(encryptedResponse)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key) decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
resp := &mgmtProto.SyncResponse{} resp := &mgmtProto.SyncResponse{}
@ -127,7 +127,7 @@ var _ = Describe("Management service", func() {
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, key) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
@ -140,7 +140,7 @@ var _ = Describe("Management service", func() {
encryptedResponse := &mgmtProto.EncryptedMessage{} encryptedResponse := &mgmtProto.EncryptedMessage{}
err = sync.RecvMsg(encryptedResponse) err = sync.RecvMsg(encryptedResponse)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, key) decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
resp := &mgmtProto.SyncResponse{} resp := &mgmtProto.SyncResponse{}
err = pb.Unmarshal(decryptedBytes, resp) err = pb.Unmarshal(decryptedBytes, resp)
@ -153,7 +153,7 @@ var _ = Describe("Management service", func() {
go func() { go func() {
err = sync.RecvMsg(encryptedResponse) err = sync.RecvMsg(encryptedResponse)
decryptedBytes, err = signal.Decrypt(encryptedResponse.Body, serverPubKey, key) decryptedBytes, err = encryption.Decrypt(encryptedResponse.Body, serverPubKey, key)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
resp = &mgmtProto.SyncResponse{} resp = &mgmtProto.SyncResponse{}
err = pb.Unmarshal(decryptedBytes, resp) err = pb.Unmarshal(decryptedBytes, resp)
@ -240,7 +240,7 @@ var _ = Describe("Management service", func() {
for _, peer := range peers { for _, peer := range peers {
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{}) messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := signal.Encrypt(messageBytes, serverPubKey, peer) encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// receive stream // receive stream
@ -261,7 +261,7 @@ var _ = Describe("Management service", func() {
} else if err != nil { } else if err != nil {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
decryptedBytes, err := signal.Decrypt(encryptedResponse.Body, serverPubKey, peer) decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
resp := &mgmtProto.SyncResponse{} resp := &mgmtProto.SyncResponse{}

View File

@ -1,44 +0,0 @@
package management
import (
pb "github.com/golang/protobuf/proto" //nolint
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/management/proto"
"github.com/wiretrustee/wiretrustee/signal"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// EncryptMessage encrypts a body of the given pn.Message and wraps into proto.EncryptedMessage
func EncryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, message pb.Message) (*proto.EncryptedMessage, error) {
byteResp, err := pb.Marshal(message)
if err != nil {
log.Errorf("failed marshalling message %v", err)
return nil, err
}
encryptedBytes, err := signal.Encrypt(byteResp, peerKey, serverPrivateKey)
if err != nil {
log.Errorf("failed encrypting SyncResponse %v", err)
return nil, err
}
return &proto.EncryptedMessage{
WgPubKey: serverPrivateKey.PublicKey().String(),
Body: encryptedBytes}, nil
}
//DecryptMessage decrypts an encrypted message (proto.EncryptedMessage)
func DecryptMessage(peerKey wgtypes.Key, serverPrivateKey wgtypes.Key, encryptedMessage *proto.EncryptedMessage, message pb.Message) error {
decrypted, err := signal.Decrypt(encryptedMessage.Body, peerKey, serverPrivateKey)
if err != nil {
log.Warnf("error while decrypting Sync request message from peer %s", peerKey.String())
return err
}
err = pb.Unmarshal(decrypted, message)
if err != nil {
log.Warnf("error while umarshalling Sync request message from peer %s", peerKey.String())
return err
}
return nil
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/management/proto" "github.com/wiretrustee/wiretrustee/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -76,7 +77,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
} }
syncReq := &proto.SyncRequest{} syncReq := &proto.SyncRequest{}
err = DecryptMessage(peerKey, s.wgKey, req, syncReq) err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
if err != nil { if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid request message") return status.Errorf(codes.InvalidArgument, "invalid request message")
} }
@ -99,12 +100,15 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
} }
log.Debugf("recevied an update for peer %s", peerKey.String()) log.Debugf("recevied an update for peer %s", peerKey.String())
encryptedResp, err := EncryptMessage(peerKey, s.wgKey, update.Update) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed processing update message") return status.Errorf(codes.Internal, "failed processing update message")
} }
err = srv.SendMsg(encryptedResp) err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed sending update message") return status.Errorf(codes.Internal, "failed sending update message")
} }
@ -200,12 +204,15 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, srv proto.ManagementServic
Peers: peers, Peers: peers,
} }
encryptedResp, err := EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
err = srv.Send(encryptedResp) err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil { if err != nil {
log.Errorf("failed sending SyncResponse %v", err) log.Errorf("failed sending SyncResponse %v", err)

View File

@ -4,8 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
pb "github.com/golang/protobuf/proto" //nolint
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/signal/proto" "github.com/wiretrustee/wiretrustee/signal/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -162,12 +162,9 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
decryptedBody, err := Decrypt(msg.GetBody(), remoteKey, c.key)
if err != nil {
return nil, err
}
body := &proto.Body{} body := &proto.Body{}
err = pb.Unmarshal(decryptedBody, body) err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -181,16 +178,13 @@ func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, er
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key // encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) { func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
body, err := pb.Marshal(msg.GetBody())
if err != nil {
return nil, err
}
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey) remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
encryptedBody, err := Encrypt(body, remoteKey, c.key) encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package signal package signal
import ( import (
"github.com/wiretrustee/wiretrustee/encryption"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"testing" "testing"
) )
@ -21,13 +22,13 @@ func TestEncryptDecrypt(t *testing.T) {
return return
} }
encryptedMessage, err := Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey) encryptedMessage, err := encryption.Encrypt(bytesMsg, peerBKey.PublicKey(), peerAKey)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
decryptedMessage, err := Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey) decryptedMessage, err := encryption.Decrypt(encryptedMessage, peerAKey.PublicKey(), peerBKey)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return