netbird/management/client/client_test.go

505 lines
13 KiB
Go

package client
import (
"context"
"net"
"path/filepath"
"sync"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/client/system"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
)
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper()
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
testDir := t.TempDir()
config := &mgmt.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
config.Datadir = testDir
err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json"))
if err != nil {
t.Fatal(err)
}
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
if err != nil {
t.Fatal(err)
}
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis
}
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
serverKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
mgmtMockServer := &mock_server.ManagementServiceServerMock{
GetServerKeyFunc: func(context.Context, *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) {
response := &mgmtProto.ServerKeyResponse{
Key: serverKey.PublicKey().String(),
}
return response, nil
},
}
mgmtProto.RegisterManagementServiceServer(s, mgmtMockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis, mgmtMockServer, serverKey
}
func closeManagementSilently(s *grpc.Server, listener net.Listener) {
s.GracefulStop()
err := listener.Close()
if err != nil {
log.Warnf("error while closing management listener %v", err)
return
}
}
func TestClient_GetServerPublicKey(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("got an empty management public key")
}
}
func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
t.Errorf("expecting err code %d denied on unregistered login got %d", codes.PermissionDenied, s.Code())
}
}
func TestClient_LoginRegistered(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil)
if err != nil {
t.Error(err)
}
if resp == nil {
t.Error("expecting non nil response, got nil")
}
}
func TestClient_Sync(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil)
if err != nil {
t.Error(err)
}
// create and register second peer (we should receive on Sync request)
remoteKey, err := wgtypes.GenerateKey()
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
t.Fatal(err)
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
if err != nil {
t.Fatal(err)
}
ch := make(chan *mgmtProto.SyncResponse, 1)
go func() {
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})
if err != nil {
return
}
}()
select {
case resp := <-ch:
if resp.GetPeerConfig() == nil {
t.Error("expecting non nil PeerConfig got nil")
}
if resp.GetWiretrusteeConfig() == nil {
t.Error("expecting non nil WiretrusteeConfig got nil")
}
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
return
}
if resp.GetRemotePeersIsEmpty() == true {
t.Error("expecting RemotePeers property to be false, got true")
}
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")
}
}
func Test_SystemMetaDataFromClient(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
testClient, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
var actualValidKey string
var wg sync.WaitGroup
wg.Add(1)
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
}
loginReq := &mgmtProto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
if err != nil {
log.Fatal(err)
}
actualMeta = loginReq.GetMeta()
actualValidKey = loginReq.GetSetupKey()
wg.Done()
loginResp := &mgmtProto.LoginResponse{}
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
wg.Wait()
protoNetAddr := make([]*mgmtProto.NetworkAddress, 0, len(info.NetworkAddresses))
for _, addr := range info.NetworkAddresses {
protoNetAddr = append(protoNetAddr, &mgmtProto.NetworkAddress{
NetIP: addr.NetIP.String(),
Mac: addr.Mac,
})
}
expectedMeta := &mgmtProto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
Kernel: info.Kernel,
Platform: info.Platform,
OS: info.OS,
Core: info.OSVersion,
OSVersion: info.OSVersion,
WiretrusteeVersion: info.WiretrusteeVersion,
KernelVersion: info.KernelVersion,
NetworkAddresses: protoNetAddr,
SysSerialNumber: info.SystemSerialNumber,
SysProductName: info.SystemProductName,
SysManufacturer: info.SystemManufacturer,
Environment: &mgmtProto.Environment{Cloud: info.Environment.Cloud, Platform: info.Environment.Platform},
}
assert.Equal(t, ValidKey, actualValidKey)
if !isEqual(expectedMeta, actualMeta) {
t.Errorf("expected and actual meta are not equal")
}
}
func isEqual(a, b *mgmtProto.PeerSystemMeta) bool {
if len(a.NetworkAddresses) != len(b.NetworkAddresses) {
return false
}
for _, addr := range a.GetNetworkAddresses() {
var found bool
for _, oAddr := range b.GetNetworkAddresses() {
if addr.GetMac() == oAddr.GetMac() && addr.GetNetIP() == oAddr.GetNetIP() {
found = true
continue
}
}
if !found {
return false
}
}
log.Infof("------")
return a.GetHostname() == b.GetHostname() &&
a.GetGoOS() == b.GetGoOS() &&
a.GetKernel() == b.GetKernel() &&
a.GetKernelVersion() == b.GetKernelVersion() &&
a.GetCore() == b.GetCore() &&
a.GetPlatform() == b.GetPlatform() &&
a.GetOS() == b.GetOS() &&
a.GetOSVersion() == b.GetOSVersion() &&
a.GetWiretrusteeVersion() == b.GetWiretrusteeVersion() &&
a.GetUiVersion() == b.GetUiVersion() &&
a.GetSysSerialNumber() == b.GetSysSerialNumber() &&
a.GetSysProductName() == b.GetSysProductName() &&
a.GetSysManufacturer() == b.GetSysManufacturer() &&
a.GetEnvironment().Cloud == b.GetEnvironment().Cloud &&
a.GetEnvironment().Platform == b.GetEnvironment().Platform
}
func Test_GetDeviceAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{ClientID: "client"},
}
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving device auth flow information")
}
assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
}
func Test_GetPKCEAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "client",
ClientSecret: "secret",
},
}
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
}