mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-27 02:24:30 +01:00
168 lines
4.0 KiB
Go
168 lines
4.0 KiB
Go
|
package client
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
log "github.com/sirupsen/logrus"
|
||
|
mgmtProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||
|
mgmt "github.com/wiretrustee/wiretrustee/management/server"
|
||
|
"github.com/wiretrustee/wiretrustee/util"
|
||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/status"
|
||
|
"net"
|
||
|
"path/filepath"
|
||
|
"testing"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
var tested *Client
|
||
|
var serverAddr string
|
||
|
|
||
|
func Test_Start(t *testing.T) {
|
||
|
level, _ := log.ParseLevel("debug")
|
||
|
log.SetLevel(level)
|
||
|
|
||
|
testKey, err := wgtypes.GenerateKey()
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
testDir := t.TempDir()
|
||
|
ctx := context.Background()
|
||
|
config := &mgmt.Config{}
|
||
|
_, err = util.ReadJson("../server/testdata/management.json", config)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
config.Datadir = testDir
|
||
|
err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json"))
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
_, listener := startManagement(config, t)
|
||
|
serverAddr = listener.Addr().String()
|
||
|
tested, err = NewClient(ctx, serverAddr, testKey, false)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Listener) {
|
||
|
lis, err := net.Listen("tcp", ":0")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
s := grpc.NewServer()
|
||
|
store, err := mgmt.NewStore(config.Datadir)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
accountManager := mgmt.NewManager(store)
|
||
|
mgmtServer, err := mgmt.NewServer(config, accountManager)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||
|
go func() {
|
||
|
if err := s.Serve(lis); err != nil {
|
||
|
t.Error(err)
|
||
|
return
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
return s, lis
|
||
|
}
|
||
|
|
||
|
func TestClient_GetServerPublicKey(t *testing.T) {
|
||
|
|
||
|
key, err := tested.GetServerPublicKey()
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
|
||
|
if key == nil {
|
||
|
t.Error("expecting non nil server key got nil")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
||
|
key, err := tested.GetServerPublicKey()
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
_, err = tested.Login(*key)
|
||
|
if err == nil {
|
||
|
t.Error("expecting err on unregistered login, got nil")
|
||
|
}
|
||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
|
||
|
t.Errorf("expecting err code %d denied on on unregistered login got %d", codes.PermissionDenied, s.Code())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestClient_LoginRegistered(t *testing.T) {
|
||
|
key, err := tested.GetServerPublicKey()
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
resp, err := tested.Register(*key, "a2c8e62b-38f5-4553-b31e-dd66c696cebb")
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
|
||
|
if resp == nil {
|
||
|
t.Error("expecting non nil response, got nil")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestClient_Sync(t *testing.T) {
|
||
|
serverKey, err := tested.GetServerPublicKey()
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
|
||
|
_, err = tested.Register(*serverKey, "a2c8e62b-38f5-4553-b31e-dd66c696cebb")
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
|
||
|
// create and register second peer (we should receive on Sync request)
|
||
|
remoteKey, err := wgtypes.GenerateKey()
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
remoteClient, err := NewClient(context.TODO(), serverAddr, remoteKey, false)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
_, err = remoteClient.Register(*serverKey, "a2c8e62b-38f5-4553-b31e-dd66c696cebb")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
ch := make(chan *mgmtProto.SyncResponse, 1)
|
||
|
|
||
|
tested.Sync(func(msg *mgmtProto.SyncResponse) error {
|
||
|
ch <- msg
|
||
|
return nil
|
||
|
})
|
||
|
|
||
|
select {
|
||
|
case resp := <-ch:
|
||
|
if resp.GetPeerConfig() == nil {
|
||
|
t.Error("expecting non nil PeerConfig got nil")
|
||
|
}
|
||
|
if resp.GetWiretrusteeConfig() == nil {
|
||
|
t.Error("expecting non nil WiretrusteeConfig got nil")
|
||
|
}
|
||
|
if len(resp.GetRemotePeers()) != 1 {
|
||
|
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
||
|
}
|
||
|
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
|
||
|
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
|
||
|
}
|
||
|
case <-time.After(3 * time.Second):
|
||
|
t.Error("timeout waiting for test to finish")
|
||
|
}
|
||
|
}
|