diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 599d36eab..02c8edea7 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -22,6 +22,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" @@ -49,8 +52,6 @@ import ( "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" "github.com/netbirdio/netbird/util" - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" ) var ( @@ -1256,7 +1257,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } info := system.GetInfo(ctx) - resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil) + resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil) if err != nil { return nil, err } diff --git a/client/internal/login.go b/client/internal/login.go index 092f2309c..395a17199 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -140,7 +140,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. config.DisableDNS, config.DisableFirewall, ) - loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey) + loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) if err != nil { log.Errorf("failed registering peer %v,%s", err, validSetupKey.String()) return nil, err diff --git a/management/client/client.go b/management/client/client.go index e9eeaccc1..950f6137e 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -15,7 +15,7 @@ type Client interface { io.Closer Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKey() (*wgtypes.Key, error) - Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) + Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) diff --git a/management/client/client_test.go b/management/client/client_test.go index 2bf802821..21f6b79ad 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -205,7 +205,7 @@ func TestClient_LoginRegistered(t *testing.T) { t.Error(err) } info := system.GetInfo(context.TODO()) - resp, err := client.Register(*key, ValidKey, "", info, nil) + resp, err := client.Register(*key, ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -235,7 +235,7 @@ func TestClient_Sync(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = client.Register(*serverKey, ValidKey, "", info, nil) + _, err = client.Register(*serverKey, ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -251,7 +251,7 @@ func TestClient_Sync(t *testing.T) { } info = system.GetInfo(context.TODO()) - _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil) + _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil) if err != nil { t.Fatal(err) } @@ -352,7 +352,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = testClient.Register(*key, ValidKey, "", info, nil) + _, err = testClient.Register(*key, ValidKey, "", info, nil, nil) if err != nil { t.Errorf("error while trying to register client: %v", err) } diff --git a/management/client/grpc.go b/management/client/grpc.go index d02509c27..d3aaffec0 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -365,12 +365,12 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Takes care of encrypting and decrypting messages. // This method will also collect system info and send it with the request (e.g. hostname, os, etc) -func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) { +func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys}) + return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. diff --git a/management/client/mock.go b/management/client/mock.go index 11564093a..9e1786f82 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -14,7 +14,7 @@ type MockClient struct { CloseFunc func() error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKeyFunc func() (*wgtypes.Key, error) - RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) + RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) @@ -46,11 +46,11 @@ func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { return m.GetServerPublicKeyFunc() } -func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) { +func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.RegisterFunc == nil { return nil, nil } - return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey) + return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels) } func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {