mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 10:31:45 +02:00
Get Device Authorization Flow information from management (#308)
We will configure the device authorization flow information and a client will retrieve it and initiate a device authorization gran flow
This commit is contained in:
@ -6,7 +6,7 @@ import (
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@ -85,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
mport := 33091
|
||||
mgmtServer, err := startManagement(mport, &Config{
|
||||
mgmtServer, err := startManagement(t, mport, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@ -301,7 +301,102 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
|
||||
|
||||
}
|
||||
|
||||
func startManagement(port int, config *Config) (*grpc.Server, error) {
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
||||
expectedComparisonMSG string
|
||||
}{
|
||||
{
|
||||
name: "Testing No Device Flow Config",
|
||||
inputFlow: nil,
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.NoError,
|
||||
expectedErrMSG: "should not return error",
|
||||
expectedComparisonFunc: require.Equal,
|
||||
expectedComparisonMSG: "should match",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
mgmtServer := &Server{
|
||||
wgKey: testingServerKey,
|
||||
config: &Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
context.TODO(),
|
||||
&mgmtProto.EncryptedMessage{
|
||||
WgPubKey: testingClientKey.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
},
|
||||
)
|
||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, error) {
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
@ -320,9 +415,10 @@ func startManagement(port int, config *Config) (*grpc.Server, error) {
|
||||
return nil, err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
t.Errorf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
Reference in New Issue
Block a user