From 8f66dea11c7c72aa9edc6e6d24ece3d46d852c1d Mon Sep 17 00:00:00 2001 From: Bethuel Date: Thu, 23 Mar 2023 16:54:31 +0300 Subject: [PATCH 01/14] Add Keycloak Idp Manager (#746) Added intergration with keycloak user API. --- management/server/idp/idp.go | 10 +- management/server/idp/keycloak.go | 581 +++++++++++++++++++++++++ management/server/idp/keycloak_test.go | 401 +++++++++++++++++ 3 files changed, 989 insertions(+), 3 deletions(-) create mode 100644 management/server/idp/keycloak.go create mode 100644 management/server/idp/keycloak_test.go diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 724a3541d..4dd950369 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -2,10 +2,11 @@ package idp import ( "fmt" - "github.com/netbirdio/netbird/management/server/telemetry" "net/http" "strings" "time" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // Manager idp manager interface @@ -20,8 +21,9 @@ type Manager interface { // Config an idp configuration struct to be loaded from management server's config file type Config struct { - ManagerType string - Auth0ClientCredentials Auth0ClientConfig + ManagerType string + Auth0ClientCredentials Auth0ClientConfig + KeycloakClientCredentials KeycloakClientConfig } // ManagerCredentials interface that authenticates using the credential of each type of idp @@ -71,6 +73,8 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) return nil, nil case "auth0": return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics) + case "keycloak": + return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go new file mode 100644 index 000000000..f9fc94ae7 --- /dev/null +++ b/management/server/idp/keycloak.go @@ -0,0 +1,581 @@ +package idp + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/telemetry" + log "github.com/sirupsen/logrus" +) + +const ( + wtAccountID = "wt_account_id" + wtPendingInvite = "wt_pending_invite" +) + +// KeycloakManager keycloak manager client instance. +type KeycloakManager struct { + adminEndpoint string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +// KeycloakClientConfig keycloak manager client configurations. +type KeycloakClientConfig struct { + ClientID string + ClientSecret string + AdminEndpoint string + TokenEndpoint string + GrantType string +} + +// KeycloakCredentials keycloak authentication information. +type KeycloakCredentials struct { + clientConfig KeycloakClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + jwtToken JWTToken + mux sync.Mutex + appMetrics telemetry.AppMetrics +} + +// keycloakUserCredential describe the authentication method for, +// newly created user profile. +type keycloakUserCredential struct { + Type string `json:"type"` + Value string `json:"value"` + Temporary bool `json:"temporary"` +} + +// keycloakUserAttributes holds additional user data fields. +type keycloakUserAttributes map[string][]string + +// createUserRequest is a user create request. +type keycloakCreateUserRequest struct { + Email string `json:"email"` + Username string `json:"username"` + Enabled bool `json:"enabled"` + EmailVerified bool `json:"emailVerified"` + Credentials []keycloakUserCredential `json:"credentials"` + Attributes keycloakUserAttributes `json:"attributes"` +} + +// keycloakProfile represents an keycloak user profile response. +type keycloakProfile struct { + ID string `json:"id"` + CreatedTimestamp int64 `json:"createdTimestamp"` + Username string `json:"username"` + Email string `json:"email"` + Attributes keycloakUserAttributes `json:"attributes"` +} + +// NewKeycloakManager creates a new instance of the KeycloakManager. +func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + + helper := JsonParser{} + + if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" { + return nil, fmt.Errorf("keycloak idp configuration is not complete") + } + + if config.GrantType != "client_credentials" { + return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials") + } + + credentials := &KeycloakCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &KeycloakManager{ + adminEndpoint: config.AdminEndpoint, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak. +func (kc *KeycloakCredentials) jwtStillValid() bool { + return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime) +} + +// requestJWTToken performs request to get jwt token. +func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { + data := url.Values{} + data.Set("client_id", kc.clientConfig.ClientID) + data.Set("client_secret", kc.clientConfig.ClientSecret) + data.Set("grant_type", kc.clientConfig.GrantType) + + payload := strings.NewReader(data.Encode()) + req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload) + if err != nil { + return nil, err + } + req.Header.Add("content-type", "application/x-www-form-urlencoded") + + log.Debug("requesting new jwt token for keycloak idp manager") + + resp, err := kc.httpClient.Do(req) + if err != nil { + if kc.appMetrics != nil { + kc.appMetrics.IDPMetrics().CountRequestError() + } + + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode) + } + + return resp, nil +} + +// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds +func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) { + jwtToken := JWTToken{} + body, err := io.ReadAll(rawBody) + if err != nil { + return jwtToken, err + } + + err = kc.helper.Unmarshal(body, &jwtToken) + if err != nil { + return jwtToken, err + } + + if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" { + return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) + } + + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + if err != nil { + return jwtToken, err + } + + // Exp maps into exp from jwt token + var IssuedAt struct{ Exp int64 } + err = kc.helper.Unmarshal(data, &IssuedAt) + if err != nil { + return jwtToken, err + } + jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0) + + return jwtToken, nil +} + +// Authenticate retrieves access token to use the keycloak Management API. +func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { + kc.mux.Lock() + defer kc.mux.Unlock() + + if kc.appMetrics != nil { + kc.appMetrics.IDPMetrics().CountAuthenticate() + } + + // reuse the token without requesting a new one if it is not expired, + // and if expiry time is sufficient time available to make a request. + if kc.jwtStillValid() { + return kc.jwtToken, nil + } + + resp, err := kc.requestJWTToken() + if err != nil { + return kc.jwtToken, err + } + defer resp.Body.Close() + + jwtToken, err := kc.parseRequestJWTResponse(resp.Body) + if err != nil { + return kc.jwtToken, err + } + + kc.jwtToken = jwtToken + + return kc.jwtToken, nil +} + +// CreateUser creates a new user in keycloak Idp and sends an invite. +func (km *KeycloakManager) CreateUser(email string, name string, accountID string) (*UserData, error) { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return nil, err + } + + invite := true + appMetadata := AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &invite, + } + + payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata) + if err != nil { + return nil, err + } + + reqURL := fmt.Sprintf("%s/users", km.adminEndpoint) + payload := strings.NewReader(payloadString) + + req, err := http.NewRequest(http.MethodPost, reqURL, payload) + if err != nil { + return nil, err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountCreateUser() + } + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) + } + + locationHeader := resp.Header.Get("location") + userID, err := extractUserIDFromLocationHeader(locationHeader) + if err != nil { + return nil, err + } + + return km.GetUserDataByID(userID, appMetadata) +} + +// GetUserByEmail searches users with a given email. +// If no users have been found, this function returns an empty list. +func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { + q := url.Values{} + q.Add("email", email) + q.Add("exact", "true") + + body, err := km.get("users", q) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles { + users = append(users, profile.userData()) + } + + return users, nil +} + +// GetUserDataByID requests user data from keycloak via ID. +func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { + body, err := km.get("users/"+userID, nil) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var profile keycloakProfile + err = km.helper.Unmarshal(body, &profile) + if err != nil { + return nil, err + } + + return profile.userData(), nil +} + +// GetAccount returns all the users for a given profile. +func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { + q := url.Values{} + q.Add("q", wtAccountID+":"+accountID) + + body, err := km.get("users", q) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetAccount() + } + + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles { + users = append(users, profile.userData()) + } + + return users, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// It returns a list of users indexed by accountID. +func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { + totalUsers, err := km.totalUsersCount() + if err != nil { + return nil, err + } + + q := url.Values{} + q.Add("max", fmt.Sprint(*totalUsers)) + + body, err := km.get("users", q) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range profiles { + userData := profile.userData() + + accountID := userData.AppMetadata.WTAccountID + if accountID != "" { + if _, ok := indexedUsers[accountID]; !ok { + indexedUsers[accountID] = make([]*UserData, 0) + } + indexedUsers[accountID] = append(indexedUsers[accountID], userData) + } + } + + return indexedUsers, nil +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return err + } + + attrs := keycloakUserAttributes{} + attrs.Set(wtAccountID, appMetadata.WTAccountID) + if appMetadata.WTPendingInvite != nil { + attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite)) + } else { + attrs.Set(wtPendingInvite, "false") + } + + reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID) + data, err := km.helper.Marshal(map[string]any{ + "attributes": attrs, + }) + if err != nil { + return err + } + payload := strings.NewReader(string(data)) + + req, err := http.NewRequest(http.MethodPut, reqURL, payload) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + log.Debugf("updating IdP metadata for user %s", userID) + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + defer resp.Body.Close() + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() + } + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode) + } + + return nil +} + +func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { + attrs := keycloakUserAttributes{} + attrs.Set(wtAccountID, appMetadata.WTAccountID) + attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite)) + + req := &keycloakCreateUserRequest{ + Email: email, + Username: name, + Enabled: true, + EmailVerified: true, + Credentials: []keycloakUserCredential{ + { + Type: "password", + Value: GeneratePassword(8, 1, 1, 1), + Temporary: false, + }, + }, + Attributes: attrs, + } + + str, err := json.Marshal(req) + if err != nil { + return "", err + } + + return string(str), nil +} + +// get perform Get requests. +func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return nil, err + } + + reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode()) + req, err := http.NewRequest(http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// totalUsersCount returns the total count of all user created. +// Used when fetching all registered accounts with pagination. +func (km *KeycloakManager) totalUsersCount() (*int, error) { + body, err := km.get("users/count", nil) + if err != nil { + return nil, err + } + + count, err := strconv.Atoi(string(body)) + if err != nil { + return nil, err + } + + return &count, nil +} + +// extractUserIDFromLocationHeader extracts the user ID from the location, +// header once the user is created successfully +func extractUserIDFromLocationHeader(locationHeader string) (string, error) { + userURL, err := url.Parse(locationHeader) + if err != nil { + return "", err + } + + return path.Base(userURL.Path), nil +} + +// userData construct user data from keycloak profile. +func (kp keycloakProfile) userData() *UserData { + accountID := kp.Attributes.Get(wtAccountID) + pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite)) + if err != nil { + pendingInvite = false + } + + return &UserData{ + Email: kp.Email, + Name: kp.Username, + ID: kp.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pendingInvite, + }, + } +} + +// Set sets the key to value. It replaces any existing +// values. +func (ka keycloakUserAttributes) Set(key, value string) { + ka[key] = []string{value} +} + +// Get returns the first value associated with the given key. +// If there are no values associated with the key, Get returns +// the empty string. +func (ka keycloakUserAttributes) Get(key string) string { + if ka == nil { + return "" + } + + values := ka[key] + if len(values) == 0 { + return "" + } + return values[0] +} diff --git a/management/server/idp/keycloak_test.go b/management/server/idp/keycloak_test.go new file mode 100644 index 000000000..00acf81bd --- /dev/null +++ b/management/server/idp/keycloak_test.go @@ -0,0 +1,401 @@ +package idp + +import ( + "fmt" + "io" + "strings" + "testing" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewKeycloakManager(t *testing.T) { + type test struct { + name string + inputConfig KeycloakClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := KeycloakClientConfig{ + ClientID: "client_id", + ClientSecret: "client_secret", + AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123", + TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token", + GrantType: "client_credentials", + } + + testCase1 := test{ + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + testCase2Config := defaultTestConfig + testCase2Config.ClientID = "" + + testCase2 := test{ + name: "Missing ClientID Configuration", + inputConfig: testCase2Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + } + + testCase5Config := defaultTestConfig + testCase5Config.GrantType = "authorization_code" + + testCase5 := test{ + name: "Wrong GrantType", + inputConfig: testCase5Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when wrong grant type", + } + + for _, testCase := range []test{testCase1, testCase2, testCase5} { + t.Run(testCase.name, func(t *testing.T) { + _, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{}) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + }) + } +} + +type mockKeycloakCredentials struct { + jwtToken JWTToken + err error +} + +func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) { + return mc.jwtToken, mc.err +} + +func TestKeycloakRequestJWTToken(t *testing.T) { + + type requestJWTTokenTest struct { + name string + inputCode int + inputRespBody string + helper ManagerHelper + expectedFuncExitErrDiff error + expectedToken string + } + exp := 5 + token := newTestJWT(t, exp) + + requestJWTTokenTesttCase1 := requestJWTTokenTest{ + name: "Good JWT Response", + inputCode: 200, + inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), + helper: JsonParser{}, + expectedToken: token, + } + requestJWTTokenTestCase2 := requestJWTTokenTest{ + name: "Request Bad Status Code", + inputCode: 400, + inputRespBody: "{}", + helper: JsonParser{}, + expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"), + expectedToken: "", + } + + for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} { + t.Run(testCase.name, func(t *testing.T) { + + jwtReqClient := mockHTTPClient{ + resBody: testCase.inputRespBody, + code: testCase.inputCode, + } + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + httpClient: &jwtReqClient, + helper: testCase.helper, + } + + resp, err := creds.requestJWTToken() + if err != nil { + if testCase.expectedFuncExitErrDiff != nil { + assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") + } else { + t.Fatal(err) + } + } else { + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err, "unable to read the response body") + + jwtToken := JWTToken{} + err = testCase.helper.Unmarshal(body, &jwtToken) + assert.NoError(t, err, "unable to parse the json input") + + assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same") + } + }) + } +} + +func TestKeycloakParseRequestJWTResponse(t *testing.T) { + type parseRequestJWTResponseTest struct { + name string + inputRespBody string + helper ManagerHelper + expectedToken string + expectedExpiresIn int + assertErrFunc assert.ErrorAssertionFunc + assertErrFuncMessage string + } + + exp := 100 + token := newTestJWT(t, exp) + + parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{ + name: "Parse Good JWT Body", + inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), + helper: JsonParser{}, + expectedToken: token, + expectedExpiresIn: exp, + assertErrFunc: assert.NoError, + assertErrFuncMessage: "no error was expected", + } + parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{ + name: "Parse Bad json JWT Body", + inputRespBody: "", + helper: JsonParser{}, + expectedToken: "", + expectedExpiresIn: 0, + assertErrFunc: assert.Error, + assertErrFuncMessage: "json error was expected", + } + + for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} { + t.Run(testCase.name, func(t *testing.T) { + rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody)) + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + helper: testCase.helper, + } + jwtToken, err := creds.parseRequestJWTResponse(rawBody) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + + assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same") + assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same") + }) + } +} + +func TestKeycloakJwtStillValid(t *testing.T) { + type jwtStillValidTest struct { + name string + inputTime time.Time + expectedResult bool + message string + } + + jwtStillValidTestCase1 := jwtStillValidTest{ + name: "JWT still valid", + inputTime: time.Now().Add(10 * time.Second), + expectedResult: true, + message: "should be true", + } + jwtStillValidTestCase2 := jwtStillValidTest{ + name: "JWT is invalid", + inputTime: time.Now(), + expectedResult: false, + message: "should be false", + } + + for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} { + t.Run(testCase.name, func(t *testing.T) { + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + } + creds.jwtToken.expiresInTime = testCase.inputTime + + assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message) + }) + } +} + +func TestKeycloakAuthenticate(t *testing.T) { + type authenticateTest struct { + name string + inputCode int + inputResBody string + inputExpireToken time.Time + helper ManagerHelper + expectedFuncExitErrDiff error + expectedCode int + expectedToken string + } + exp := 5 + token := newTestJWT(t, exp) + + authenticateTestCase1 := authenticateTest{ + name: "Get Cached token", + inputExpireToken: time.Now().Add(30 * time.Second), + helper: JsonParser{}, + expectedFuncExitErrDiff: nil, + expectedCode: 200, + expectedToken: "", + } + + authenticateTestCase2 := authenticateTest{ + name: "Get Good JWT Response", + inputCode: 200, + inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), + helper: JsonParser{}, + expectedCode: 200, + expectedToken: token, + } + + authenticateTestCase3 := authenticateTest{ + name: "Get Bad Status Code", + inputCode: 400, + inputResBody: "{}", + helper: JsonParser{}, + expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"), + expectedCode: 200, + expectedToken: "", + } + + for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} { + t.Run(testCase.name, func(t *testing.T) { + + jwtReqClient := mockHTTPClient{ + resBody: testCase.inputResBody, + code: testCase.inputCode, + } + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + httpClient: &jwtReqClient, + helper: testCase.helper, + } + creds.jwtToken.expiresInTime = testCase.inputExpireToken + + _, err := creds.Authenticate() + if err != nil { + if testCase.expectedFuncExitErrDiff != nil { + assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") + } else { + t.Fatal(err) + } + } + + assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same") + }) + } +} + +func TestKeycloakUpdateUserAppMetadata(t *testing.T) { + type updateUserAppMetadataTest struct { + name string + inputReqBody string + expectedReqBody string + appMetadata AppMetadata + statusCode int + helper ManagerHelper + managerCreds ManagerCredentials + assertErrFunc assert.ErrorAssertionFunc + assertErrFuncMessage string + } + + appMetadata := AppMetadata{WTAccountID: "ok"} + + updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ + name: "Bad Authentication", + expectedReqBody: "", + appMetadata: appMetadata, + statusCode: 400, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + err: fmt.Errorf("error"), + }, + assertErrFunc: assert.Error, + assertErrFuncMessage: "should return error", + } + + updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ + name: "Bad Status Code", + expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID), + appMetadata: appMetadata, + statusCode: 400, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.Error, + assertErrFuncMessage: "should return error", + } + + updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{ + name: "Bad Response Parsing", + statusCode: 400, + helper: &mockJsonParser{marshalErrorString: "error"}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.Error, + assertErrFuncMessage: "should return error", + } + + updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ + name: "Good request", + expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID), + appMetadata: appMetadata, + statusCode: 204, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + invite := true + updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{ + name: "Update Pending Invite", + expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID), + appMetadata: AppMetadata{ + WTAccountID: "ok", + WTPendingInvite: &invite, + }, + statusCode: 204, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, + updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} { + t.Run(testCase.name, func(t *testing.T) { + reqClient := mockHTTPClient{ + resBody: testCase.inputReqBody, + code: testCase.statusCode, + } + + manager := &KeycloakManager{ + httpClient: &reqClient, + credentials: testCase.managerCreds, + helper: testCase.helper, + } + + err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + + assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match") + }) + } +} From 628b497e8198017950fe4ca5fe3facae2952586e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 23 Mar 2023 16:35:06 +0100 Subject: [PATCH 02/14] Adjustments for the change server flow (#756) Check SSO support by calling the internal.GetDeviceAuthorizationFlowInfo Rename LoginSaveConfigIfSSOSupported to SaveConfigIfSSOSupported Receive device name as input for setup-key login have a default android name when no context value is provided log non parsed errors from management registration calls --- client/android/login.go | 52 +++++++++++++++++++++------------ client/system/info_android.go | 2 +- management/server/grpcserver.go | 11 ++++--- 3 files changed, 42 insertions(+), 23 deletions(-) diff --git a/client/android/login.go b/client/android/login.go index e4cb5513d..4e2f1ab30 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -3,14 +3,18 @@ package android import ( "context" "fmt" - "github.com/cenkalti/backoff/v4" - "github.com/netbirdio/netbird/client/cmd" "time" - "github.com/netbirdio/netbird/client/internal" + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/cmd" + "github.com/netbirdio/netbird/client/system" + + "github.com/netbirdio/netbird/client/internal" ) // URLOpener it is a callback interface. The Open function will be triggered if @@ -52,32 +56,44 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { } } -// LoginAndSaveConfigIfSSOSupported test the connectivity with the management server. -// If the SSO is supported than save the configuration. Return with the SSO login is supported or not. -func (a *Auth) LoginAndSaveConfigIfSSOSupported() (bool, error) { - var needsLogin bool +// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info. +// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO +// is not supported and returns false without saving the configuration. For other errors return false. +func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { + supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey) - return + _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { + supportsSSO = false + err = nil + } + return err }) + + if !supportsSSO { + return false, nil + } + if err != nil { return false, fmt.Errorf("backoff cycle failed: %v", err) } - if !needsLogin { - return false, nil - } + err = internal.WriteOutConfig(a.cfgPath, a.config) - return needsLogin, err + return true, err } // LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. -func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error { +func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { + //nolint + ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) + err := a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, setupKey, "") - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - return nil + backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") + if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { + // we got an answer from management, exit backoff earlier + return backoff.Permanent(backoffErr) } - return err + return backoffErr }) if err != nil { return fmt.Errorf("backoff cycle failed: %v", err) diff --git a/client/system/info_android.go b/client/system/info_android.go index 65fb409f6..9ea9c0487 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -34,7 +34,7 @@ func GetInfo(ctx context.Context) *Info { func extractDeviceName(ctx context.Context) string { v, ok := ctx.Value(DeviceNameCtxKey).(string) if !ok { - return "" + return "android" } return v } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index fa0e49ed3..45be9815c 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,24 +3,26 @@ package server import ( "context" "fmt" - pb "github.com/golang/protobuf/proto" //nolint "strings" "time" + pb "github.com/golang/protobuf/proto" //nolint + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/golang/protobuf/ptypes/timestamp" - "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/management/proto" - internalStatus "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" gRPCPeer "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/proto" + internalStatus "github.com/netbirdio/netbird/management/server/status" ) // GRPCServer an instance of a Management gRPC API server @@ -222,6 +224,7 @@ func mapError(err error) error { default: } } + log.Errorf("got an unhandled error: %s", err) return status.Errorf(codes.Internal, "failed handling request") } From e6292e3124f85fffce6042547ffaea0d3e45d7c0 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 23 Mar 2023 17:47:53 +0100 Subject: [PATCH 03/14] Disable peer expiration of peers added with setup keys (#758) --- management/server/account.go | 4 ++-- management/server/account_test.go | 37 +++++++++++++++++++++++++++++++ management/server/peer.go | 2 +- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 1d4c10721..01cae2e64 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -358,11 +358,11 @@ func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { return *nextExpiry, true } -// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user func (a *Account) GetPeersWithExpiration() []*Peer { peers := make([]*Peer, 0) for _, peer := range a.Peers { - if peer.LoginExpirationEnabled { + if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { peers = append(peers, peer) } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 5b4b1cc17..af894817b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1605,9 +1605,11 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { peers: map[string]*Peer{ "peer-1": { LoginExpirationEnabled: false, + UserID: userID, }, "peer-2": { LoginExpirationEnabled: false, + UserID: userID, }, }, expectedPeers: map[string]struct{}{}, @@ -1618,9 +1620,11 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { "peer-1": { ID: "peer-1", LoginExpirationEnabled: true, + UserID: userID, }, "peer-2": { LoginExpirationEnabled: false, + UserID: userID, }, }, expectedPeers: map[string]struct{}{ @@ -1680,12 +1684,14 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { Connected: false, }, LoginExpirationEnabled: true, + UserID: userID, }, "peer-2": { Status: &PeerStatus{ Connected: true, }, LoginExpirationEnabled: false, + UserID: userID, }, }, expiration: time.Second, @@ -1701,12 +1707,14 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { Connected: true, }, LoginExpirationEnabled: false, + UserID: userID, }, "peer-2": { Status: &PeerStatus{ Connected: true, }, LoginExpirationEnabled: false, + UserID: userID, }, }, expiration: time.Second, @@ -1723,6 +1731,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: true, }, LoginExpirationEnabled: true, + UserID: userID, }, "peer-2": { Status: &PeerStatus{ @@ -1730,6 +1739,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: true, }, LoginExpirationEnabled: true, + UserID: userID, }, }, expiration: time.Second, @@ -1747,6 +1757,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { }, LoginExpirationEnabled: true, LastLogin: time.Now(), + UserID: userID, }, "peer-2": { Status: &PeerStatus{ @@ -1754,6 +1765,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: true, }, LoginExpirationEnabled: true, + UserID: userID, }, }, expiration: time.Minute, @@ -1761,6 +1773,31 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { expectedNextRun: true, expectedNextExpiration: expectedNextExpiration, }, + { + name: "Peers added with setup keys, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: false, + }, + LoginExpirationEnabled: true, + SetupKey: "key", + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: false, + }, + LoginExpirationEnabled: true, + SetupKey: "key", + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { diff --git a/management/server/peer.go b/management/server/peer.go index b5505f912..7b5ca539f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -528,7 +528,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* SSHEnabled: false, SSHKey: peer.SSHKey, LastLogin: time.Now(), - LoginExpirationEnabled: true, + LoginExpirationEnabled: addedByUser, } // add peer to 'All' group From a27fe4326c32be4a44ea2b415bb787a65ec851d3 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 23 Mar 2023 18:26:41 +0100 Subject: [PATCH 04/14] Add JWT middleware validation failure log (#760) We will log the middleware log now, but in the next releases we should provide a generic error that can be parsed by the dashboard. --- management/server/http/middleware/jwt.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go index feb00ec86..1ac6d3948 100644 --- a/management/server/http/middleware/jwt.go +++ b/management/server/http/middleware/jwt.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" - "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" - "log" "net/http" "strings" + + "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/status" ) // A function called whenever an error is encountered @@ -114,6 +116,9 @@ func (m *JWTMiddleware) Handler(h http.Handler) http.Handler { // If there was an error, do not continue. if err != nil { + log.Errorf("received an error while validating the JWT token: %s. "+ + "Review your IDP configuration and ensure that "+ + "settings are in sync between dashboard and management", err) return } From d1703479ff236d56d17fe8daf66199c99590497a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 24 Mar 2023 08:40:39 +0100 Subject: [PATCH 05/14] Add custom ice stdnet implementation (#754) On Android, because of the hard SELinux policies can not list the interfaces of the ICE package. Without it can not generate a host type candidate. In this pull request, the list of interfaces comes via the Java interface. --- .github/workflows/golang-test-linux.yml | 4 +- client/android/client.go | 12 +- client/cmd/up.go | 2 +- client/internal/connect.go | 8 +- client/internal/engine.go | 19 +++- client/internal/engine_stdnet.go | 11 ++ client/internal/engine_stdnet_android.go | 7 ++ client/internal/peer/conn.go | 15 ++- client/internal/peer/conn_test.go | 12 +- client/internal/peer/stdnet.go | 11 ++ client/internal/peer/stdnet_android.go | 7 ++ client/internal/stdnet/iface_discover.go | 8 ++ client/internal/stdnet/stdnet.go | 137 +++++++++++++++++++++++ client/internal/stdnet/stdnet_test.go | 52 +++++++++ client/server/server.go | 4 +- 15 files changed, 285 insertions(+), 24 deletions(-) create mode 100644 client/internal/engine_stdnet.go create mode 100644 client/internal/engine_stdnet_android.go create mode 100644 client/internal/peer/stdnet.go create mode 100644 client/internal/peer/stdnet_android.go create mode 100644 client/internal/stdnet/iface_discover.go create mode 100644 client/internal/stdnet/stdnet.go create mode 100644 client/internal/stdnet/stdnet_test.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 4186baf38..d600575e6 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -72,7 +72,7 @@ jobs: run: go test -c -o routemanager-testing.bin ./client/internal/routemanager/... - name: Generate Engine Test bin - run: go test -c -o engine-testing.bin ./client/internal/*.go + run: go test -c -o engine-testing.bin ./client/internal - name: Generate Peer Test bin run: go test -c -o peer-testing.bin ./client/internal/peer/... @@ -89,4 +89,4 @@ jobs: run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - name: Run Peer tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 \ No newline at end of file + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/client/android/client.go b/client/android/client.go index 778c3d15a..ac16316ed 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/iface" @@ -23,6 +24,11 @@ type TunAdapter interface { iface.TunAdapter } +// IFaceDiscover export internal IFaceDiscover for mobile +type IFaceDiscover interface { + stdnet.IFaceDiscover +} + func init() { formatter.SetLogcatFormatter(log.StandardLogger()) } @@ -31,6 +37,7 @@ func init() { type Client struct { cfgFile string tunAdapter iface.TunAdapter + iFaceDiscover IFaceDiscover recorder *peer.Status ctxCancel context.CancelFunc ctxCancelLock *sync.Mutex @@ -38,7 +45,7 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client { +func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client { lvl, _ := log.ParseLevel("trace") log.SetLevel(lvl) @@ -46,6 +53,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client { cfgFile: cfgFile, deviceName: deviceName, tunAdapter: tunAdapter, + iFaceDiscover: iFaceDiscover, recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, } @@ -77,7 +85,7 @@ func (c *Client) Run(urlOpener URLOpener) error { // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter) + return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover) } // Stop the internal client and free the resources diff --git a/client/cmd/up.go b/client/cmd/up.go index 5bbdab690..fc576e8d4 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) SetupCloseHandler(ctx, cancel) - return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil) + return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil) } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { diff --git a/client/internal/connect.go b/client/internal/connect.go index eeb0e640e..3aca0bab9 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -13,6 +13,7 @@ import ( gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/iface" @@ -22,7 +23,7 @@ import ( ) // RunClient with main logic. -func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error { +func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error { backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -146,7 +147,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter) + engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover) if err != nil { log.Error(err) return wrapErr(err) @@ -193,12 +194,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, } // createEngineConfig converts configuration received from Management Service to EngineConfig -func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) { +func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) { engineConf := &EngineConfig{ WgIfaceName: config.WgIface, WgAddr: peerConfig.Address, TunAdapter: tunAdapter, + IFaceDiscover: iFaceDiscover, IFaceBlackList: config.IFaceBlackList, DisableIPv6Discovery: config.DisableIPv6Discovery, WgPrivateKey: key, diff --git a/client/internal/engine.go b/client/internal/engine.go index 10d74d931..a7fa82c11 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/client/internal/stdnet" nbssh "github.com/netbirdio/netbird/client/ssh" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" @@ -49,6 +50,8 @@ type EngineConfig struct { // TunAdapter is option. It is necessary for mobile version. TunAdapter iface.TunAdapter + IFaceDiscover stdnet.IFaceDiscover + // WgAddr is a Wireguard local address (Netbird Network IP) WgAddr string @@ -186,12 +189,22 @@ func (e *Engine) Start() error { networkName = "udp4" } + transportNet, err := e.newStdNet() + if err != nil { + log.Warnf("failed to create pion's stdnet: %s", err) + } + e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort}) if err != nil { log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error()) e.close() return err } + udpMuxParams := ice.UDPMuxParams{ + UDPConn: e.udpMuxConn, + Net: transportNet, + } + e.udpMux = ice.NewUDPMuxDefault(udpMuxParams) e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort}) if err != nil { @@ -199,9 +212,7 @@ func (e *Engine) Start() error { e.close() return err } - - e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn}) - e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx}) + e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet}) err = e.wgInterface.Create() if err != nil { @@ -813,7 +824,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er NATExternalIPs: e.parseNATExternalIPMappings(), } - peerConn, err := peer.NewConn(config, e.statusRecorder) + peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover) if err != nil { return nil, err } diff --git a/client/internal/engine_stdnet.go b/client/internal/engine_stdnet.go new file mode 100644 index 000000000..b4e05768c --- /dev/null +++ b/client/internal/engine_stdnet.go @@ -0,0 +1,11 @@ +//go:build !android + +package internal + +import ( + "github.com/pion/transport/v2/stdnet" +) + +func (e *Engine) newStdNet() (*stdnet.Net, error) { + return stdnet.NewNet() +} diff --git a/client/internal/engine_stdnet_android.go b/client/internal/engine_stdnet_android.go new file mode 100644 index 000000000..976ffd656 --- /dev/null +++ b/client/internal/engine_stdnet_android.go @@ -0,0 +1,7 @@ +package internal + +import "github.com/netbirdio/netbird/client/internal/stdnet" + +func (e *Engine) newStdNet() (*stdnet.Net, error) { + return stdnet.NewNet(e.config.IFaceDiscover) +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index e42e6305d..ee45a6ba0 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -9,15 +9,15 @@ import ( "time" "github.com/pion/ice/v2" - "github.com/pion/transport/v2/stdnet" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl" "github.com/netbirdio/netbird/client/internal/proxy" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/version" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + "github.com/netbirdio/netbird/version" ) // ConnConfig is a peer Connection configuration @@ -93,6 +93,9 @@ type Conn struct { proxy proxy.Proxy remoteModeCh chan ModeMessage meta meta + + adapter iface.TunAdapter + iFaceDiscover stdnet.IFaceDiscover } // meta holds meta information about a connection @@ -118,7 +121,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) { // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) { +func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) { return &Conn{ config: config, mu: sync.Mutex{}, @@ -128,6 +131,8 @@ func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) { remoteAnswerCh: make(chan OfferAnswer), statusRecorder: statusRecorder, remoteModeCh: make(chan ModeMessage, 1), + adapter: adapter, + iFaceDiscover: iFaceDiscover, }, nil } @@ -162,7 +167,9 @@ func (conn *Conn) reCreateAgent() error { defer conn.mu.Unlock() failedTimeout := 6 * time.Second - transportNet, err := stdnet.NewNet() + + var err error + transportNet, err := conn.newStdNet() if err != nil { log.Warnf("failed to create pion's stdnet: %s", err) } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 7f9b263e4..ddee91800 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -37,7 +37,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - conn, err := NewConn(connConf, nil) + conn, err := NewConn(connConf, nil, nil, nil) if err != nil { return } @@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder("https://mgm")) + conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) if err != nil { return } @@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder("https://mgm")) + conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) if err != nil { return } @@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { } func TestConn_Status(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder("https://mgm")) + conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) if err != nil { return } @@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) { func TestConn_Close(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder("https://mgm")) + conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil) if err != nil { return } @@ -411,7 +411,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { g := errgroup.Group{} - conn, err := NewConn(connConf, nil) + conn, err := NewConn(connConf, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go new file mode 100644 index 000000000..588aaa929 --- /dev/null +++ b/client/internal/peer/stdnet.go @@ -0,0 +1,11 @@ +//go:build !android + +package peer + +import ( + "github.com/pion/transport/v2/stdnet" +) + +func (conn *Conn) newStdNet() (*stdnet.Net, error) { + return stdnet.NewNet() +} diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go new file mode 100644 index 000000000..71a962c21 --- /dev/null +++ b/client/internal/peer/stdnet_android.go @@ -0,0 +1,7 @@ +package peer + +import "github.com/netbirdio/netbird/client/internal/stdnet" + +func (conn *Conn) newStdNet() (*stdnet.Net, error) { + return stdnet.NewNet(conn.iFaceDiscover) +} diff --git a/client/internal/stdnet/iface_discover.go b/client/internal/stdnet/iface_discover.go new file mode 100644 index 000000000..cbe306e2e --- /dev/null +++ b/client/internal/stdnet/iface_discover.go @@ -0,0 +1,8 @@ +package stdnet + +// IFaceDiscover provide an option for external services (mobile) +// to collect network interface information +type IFaceDiscover interface { + // IFaces return with the description of the interfaces + IFaces() (string, error) +} diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go new file mode 100644 index 000000000..311306535 --- /dev/null +++ b/client/internal/stdnet/stdnet.go @@ -0,0 +1,137 @@ +// Package stdnet is an extension of the pion's stdnet. +// With it the list of the interface can come from external source. +// More info: https://github.com/golang/go/issues/40569 +package stdnet + +import ( + "fmt" + "net" + "strings" + + "github.com/pion/transport/v2" + "github.com/pion/transport/v2/stdnet" + log "github.com/sirupsen/logrus" +) + +// Net is an implementation of the net.Net interface +// based on functions of the standard net package. +type Net struct { + stdnet.Net + interfaces []*transport.Interface +} + +// NewNet creates a new StdNet instance. +func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) { + n := &Net{} + + return n, n.UpdateInterfaces(iFaceDiscover) +} + +// UpdateInterfaces updates the internal list of network interfaces +// and associated addresses. +func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error { + ifacesString, err := iFaceDiscover.IFaces() + if err != nil { + return err + } + n.interfaces = parseInterfacesString(ifacesString) + return err +} + +// Interfaces returns a slice of interfaces which are available on the +// system +func (n *Net) Interfaces() ([]*transport.Interface, error) { + return n.interfaces, nil +} + +// InterfaceByIndex returns the interface specified by index. +// +// On Solaris, it returns one of the logical network interfaces +// sharing the logical data link; for more precision use +// InterfaceByName. +func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { + for _, ifc := range n.interfaces { + if ifc.Index == index { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index) +} + +// InterfaceByName returns the interface specified by name. +func (n *Net) InterfaceByName(name string) (*transport.Interface, error) { + for _, ifc := range n.interfaces { + if ifc.Name == name { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name) +} + +func parseInterfacesString(interfaces string) []*transport.Interface { + ifs := []*transport.Interface{} + + for _, iface := range strings.Split(interfaces, "\n") { + if strings.TrimSpace(iface) == "" { + continue + } + + fields := strings.Split(iface, "|") + if len(fields) != 2 { + log.Warnf("parseInterfacesString: unable to split %q", iface) + continue + } + + var name string + var index, mtu int + var up, broadcast, loopback, pointToPoint, multicast bool + _, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t", + &name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast) + if err != nil { + log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err) + continue + } + + newIf := net.Interface{ + Name: name, + Index: index, + MTU: mtu, + } + if up { + newIf.Flags |= net.FlagUp + } + if broadcast { + newIf.Flags |= net.FlagBroadcast + } + if loopback { + newIf.Flags |= net.FlagLoopback + } + if pointToPoint { + newIf.Flags |= net.FlagPointToPoint + } + if multicast { + newIf.Flags |= net.FlagMulticast + } + + ifc := transport.NewInterface(newIf) + + addrs := strings.Trim(fields[1], " \n") + foundAddress := false + for _, addr := range strings.Split(addrs, " ") { + ip, ipNet, err := net.ParseCIDR(addr) + if err != nil { + log.Warnf("%s", err) + continue + } + ipNet.IP = ip + ifc.AddAddress(ipNet) + foundAddress = true + } + if foundAddress { + ifs = append(ifs, ifc) + } + } + return ifs +} diff --git a/client/internal/stdnet/stdnet_test.go b/client/internal/stdnet/stdnet_test.go new file mode 100644 index 000000000..f3c09c61e --- /dev/null +++ b/client/internal/stdnet/stdnet_test.go @@ -0,0 +1,52 @@ +package stdnet + +import ( + "fmt" + "testing" +) + +func Test_parseInterfacesString(t *testing.T) { + testData := []struct { + name string + index int + mtu int + up bool + broadcast bool + loopBack bool + pointToPoint bool + multicast bool + addr string + }{ + {"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"}, + {"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"}, + } + + var exampleString string + for _, d := range testData { + exampleString = fmt.Sprintf("%s\n%s %d %d %t %t %t %t %t | %s", exampleString, + d.name, + d.index, + d.mtu, + d.up, + d.broadcast, + d.loopBack, + d.pointToPoint, + d.multicast, + d.addr) + } + nets := parseInterfacesString(exampleString) + if len(nets) == 0 { + t.Fatalf("failed to parse interfaces") + } + + for i, net := range nets { + if net.MTU != testData[i].mtu { + t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu) + + } + + if net.Interface.Name != testData[i].name { + t.Errorf("invalid interface name: %s, expected: %s", net.Interface.Name, testData[i].name) + } + } +} diff --git a/client/server/server.go b/client/server/server.go index 238b15acc..44502b148 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -102,7 +102,7 @@ func (s *Server) Start() error { } go func() { - if err := internal.RunClient(ctx, config, s.statusRecorder, nil); err != nil { + if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil { log.Errorf("init connections: %v", err) } }() @@ -394,7 +394,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } go func() { - if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil); err != nil { + if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil { log.Errorf("run client connection: %v", err) return } From 992cfe64e1fda90717240716c47f85251beb1d83 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 24 Mar 2023 10:46:40 +0100 Subject: [PATCH 06/14] Add ipv6 test for stdnet pkg (#761) --- client/internal/stdnet/stdnet_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/client/internal/stdnet/stdnet_test.go b/client/internal/stdnet/stdnet_test.go index f3c09c61e..f6e0cfbcb 100644 --- a/client/internal/stdnet/stdnet_test.go +++ b/client/internal/stdnet/stdnet_test.go @@ -19,6 +19,7 @@ func Test_parseInterfacesString(t *testing.T) { }{ {"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"}, {"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"}, + {"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"}, } var exampleString string @@ -48,5 +49,18 @@ func Test_parseInterfacesString(t *testing.T) { if net.Interface.Name != testData[i].name { t.Errorf("invalid interface name: %s, expected: %s", net.Interface.Name, testData[i].name) } + + addr, err := net.Addrs() + if err != nil { + t.Fatal(err) + } + + if len(addr) == 0 { + t.Errorf("invalid address parsing") + } + + if addr[0].String() != testData[i].addr { + t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr) + } } } From 71d24e59e6ce84e45bb01a00559b7e29528f6bba Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 24 Mar 2023 18:51:35 +0100 Subject: [PATCH 07/14] Add fqdn and address for notification listener (#757) Extend the status notification listeners with FQDN and address changes. It is required for mobile services. --- client/internal/peer/listener.go | 1 + client/internal/peer/notifier.go | 9 +++++++++ client/internal/peer/status.go | 10 ++++++++-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/client/internal/peer/listener.go b/client/internal/peer/listener.go index 9324c6773..c8dc0fe70 100644 --- a/client/internal/peer/listener.go +++ b/client/internal/peer/listener.go @@ -5,5 +5,6 @@ type Listener interface { OnConnected() OnDisconnected() OnConnecting() + OnAddressChanged(string, string) OnPeersListChanged(int) } diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go index db1c32e97..efc9e47ad 100644 --- a/client/internal/peer/notifier.go +++ b/client/internal/peer/notifier.go @@ -122,3 +122,12 @@ func (n *notifier) peerListChanged(numOfPeers int) { l.OnPeersListChanged(numOfPeers) } } + +func (n *notifier) localAddressChanged(fqdn, address string) { + n.listenersLock.Lock() + defer n.listenersLock.Unlock() + + for l := range n.listeners { + l.OnAddressChanged(fqdn, address) + } +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index b0a3f338e..1ecdff301 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -190,6 +190,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { defer d.mux.Unlock() d.localPeer = localPeerState + d.notifyAddressChanged() } // CleanLocalPeerState cleans local peer status @@ -198,6 +199,7 @@ func (d *Status) CleanLocalPeerState() { defer d.mux.Unlock() d.localPeer = LocalPeerState{} + d.notifyAddressChanged() } // MarkManagementDisconnected sets ManagementState to disconnected @@ -215,7 +217,7 @@ func (d *Status) MarkManagementConnected() { defer d.mux.Unlock() defer d.onConnectionChanged() - d.managementState = true + d.managementState = true } // UpdateSignalAddress update the address of the signal server @@ -238,7 +240,7 @@ func (d *Status) MarkSignalDisconnected() { defer d.mux.Unlock() defer d.onConnectionChanged() - d.signalState = false + d.signalState = false } // MarkSignalConnected sets SignalState to connected @@ -303,3 +305,7 @@ func (d *Status) onConnectionChanged() { func (d *Status) notifyPeerListChanged() { d.notifier.peerListChanged(len(d.peers)) } + +func (d *Status) notifyAddressChanged() { + d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) +} From 55ebf93815d3bc1142d82c92c8ff8ad5a36f25bc Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 27 Mar 2023 15:37:58 +0200 Subject: [PATCH 08/14] Fix nil pointer exception when create config (#765) The config stored in a wrong variable when has been generated a new config --- client/server/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/server/server.go b/client/server/server.go index 44502b148..6d5a08c59 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -78,7 +78,7 @@ func (s *Server) Start() error { // on failure we return error to retry config, err := internal.UpdateConfig(s.latestConfigInput) if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound { - config, err = internal.UpdateOrCreateConfig(s.latestConfigInput) + s.config, err = internal.UpdateOrCreateConfig(s.latestConfigInput) if err != nil { log.Warnf("unable to create configuration file: %v", err) return err From 488d338ce83d98921964a0ea0e853688401919b4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Mar 2023 09:57:23 +0200 Subject: [PATCH 09/14] Refactor the authentication part of mobile exports (#759) Refactor the auth code into async calls for mobile framework --------- Co-authored-by: Maycon Santos --- client/android/client.go | 2 +- client/android/login.go | 51 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/client/android/client.go b/client/android/client.go index ac16316ed..5e3c0c85a 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -78,7 +78,7 @@ func (c *Client) Run(urlOpener URLOpener) error { c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.Login(urlOpener) + err = auth.login(urlOpener) if err != nil { return err } diff --git a/client/android/login.go b/client/android/login.go index 4e2f1ab30..0c11c0cce 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -17,6 +17,18 @@ import ( "github.com/netbirdio/netbird/client/internal" ) +// SSOListener is async listener for mobile framework +type SSOListener interface { + OnSuccess(bool) + OnError(error) +} + +// ErrListener is async listener for mobile framework +type ErrListener interface { + OnSuccess() + OnError(error) +} + // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { @@ -59,7 +71,18 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { // SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info. // If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO // is not supported and returns false without saving the configuration. For other errors return false. -func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { +func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { + go func() { + sso, err := a.saveConfigIfSSOSupported() + if err != nil { + listener.OnError(err) + } else { + listener.OnSuccess(sso) + } + }() +} + +func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) @@ -83,7 +106,18 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { } // LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. -func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { +func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) { + go func() { + err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) @@ -103,7 +137,18 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string } // Login try register the client on the server -func (a *Auth) Login(urlOpener URLOpener) error { +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { + go func() { + err := a.login(urlOpener) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) login(urlOpener URLOpener) error { var needsLogin bool // check if we need to generate JWT token From 8ebd6ce9632f1aa533148ac3bec3f48bb6cfe419 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 29 Mar 2023 10:39:54 +0200 Subject: [PATCH 10/14] Add OnDisconnecting service callback (#767) Add OnDisconnecting service callback for mobile --- client/internal/connect.go | 1 + client/internal/peer/listener.go | 1 + client/internal/peer/notifier.go | 17 ++++++++++++++++- client/internal/peer/status.go | 5 +++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 3aca0bab9..47c63e6d0 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -164,6 +164,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, state.Set(StatusConnected) <-engineCtx.Done() + statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/peer/listener.go b/client/internal/peer/listener.go index c8dc0fe70..c601fe534 100644 --- a/client/internal/peer/listener.go +++ b/client/internal/peer/listener.go @@ -5,6 +5,7 @@ type Listener interface { OnConnected() OnDisconnected() OnConnecting() + OnDisconnecting() OnAddressChanged(string, string) OnPeersListChanged(int) } diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go index efc9e47ad..4e618d2f8 100644 --- a/client/internal/peer/notifier.go +++ b/client/internal/peer/notifier.go @@ -8,6 +8,7 @@ const ( stateDisconnected = iota stateConnected stateConnecting + stateDisconnecting ) type notifier struct { @@ -57,8 +58,12 @@ func (n *notifier) updateServerStates(mgmState bool, signalState bool) { } n.currentServerState = newState - n.lastNotification = n.calculateState(newState, n.currentClientState) + if n.lastNotification == stateDisconnecting { + return + } + + n.lastNotification = n.calculateState(newState, n.currentClientState) go n.notifyAll(n.lastNotification) } @@ -78,6 +83,14 @@ func (n *notifier) clientStop() { go n.notifyAll(n.lastNotification) } +func (n *notifier) clientTearDown() { + n.serverStateLock.Lock() + defer n.serverStateLock.Unlock() + n.currentClientState = false + n.lastNotification = stateDisconnecting + go n.notifyAll(n.lastNotification) +} + func (n *notifier) isServerStateChanged(newState bool) bool { return n.currentServerState != newState } @@ -99,6 +112,8 @@ func (n *notifier) notifyListener(l Listener, state int) { l.OnConnected() case stateConnecting: l.OnConnecting() + case stateDisconnecting: + l.OnDisconnecting() } } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 1ecdff301..62841d6fc 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -288,6 +288,11 @@ func (d *Status) ClientStop() { d.notifier.clientStop() } +// ClientTeardown will notify all listeners about the service is under teardown +func (d *Status) ClientTeardown() { + d.notifier.clientTearDown() +} + // AddConnectionListener add a listener to the notifier func (d *Status) AddConnectionListener(listener Listener) { d.notifier.addListener(listener) From ab0cf1b8aa1ebf1d9701f65ca503eb505cb7ca20 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 29 Mar 2023 10:40:31 +0200 Subject: [PATCH 11/14] Fix slice bounds out of range in msg decryption (#768) --- encryption/encryption.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/encryption/encryption.go b/encryption/encryption.go index 196c42106..1c6ec7806 100644 --- a/encryption/encryption.go +++ b/encryption/encryption.go @@ -3,10 +3,13 @@ package encryption import ( "crypto/rand" "fmt" + "golang.org/x/crypto/nacl/box" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const nonceSize = 24 + // A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service // These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate) // Wireguard keys are used for encryption @@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. if err != nil { return nil, err } - copy(nonce[:], encryptedMsg[:24]) - opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) + if len(encryptedMsg) < nonceSize { + return nil, fmt.Errorf("invalid encrypted message lenght") + } + copy(nonce[:], encryptedMsg[:nonceSize]) + opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) if !ok { return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String()) } @@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. } // Generates nonce of size 24 -func genNonce() (*[24]byte, error) { - var nonce [24]byte +func genNonce() (*[nonceSize]byte, error) { + var nonce [nonceSize]byte if _, err := rand.Read(nonce[:]); err != nil { return nil, err } From dfb7960cd474d3475c13c06e535f36f57d935a7d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 29 Mar 2023 10:41:14 +0200 Subject: [PATCH 12/14] Fix pre-shared key query name for android configuration (#773) --- iface/ipc_parser_android.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iface/ipc_parser_android.go b/iface/ipc_parser_android.go index ef757a638..e1dd66856 100644 --- a/iface/ipc_parser_android.go +++ b/iface/ipc_parser_android.go @@ -33,7 +33,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { if p.PresharedKey != nil { preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) - sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey)) + sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) } if p.Remove { From a7519859bccf2f4a02c672437ed3bd4196ace724 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 14:15:44 +0200 Subject: [PATCH 13/14] fix test --- management/server/user_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/management/server/user_test.go b/management/server/user_test.go index 1dd12e57b..238aa2bff 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -197,8 +197,4 @@ func TestUser_GetAllPATs(t *testing.T) { } assert.Equal(t, 2, len(pats)) - assert.Equal(t, mockTokenID1, pats[0].ID) - assert.Equal(t, mockToken1, pats[0].HashedToken) - assert.Equal(t, mockTokenID2, pats[1].ID) - assert.Equal(t, mockToken2, pats[1].HashedToken) } From 5e2f66d59142750c2ce9491af4b5d2dd607e12c4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 15:23:24 +0200 Subject: [PATCH 14/14] fix codacy --- management/server/account.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index ce2d2fc1e..2c146b3ad 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -67,10 +67,10 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) - CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) - DeletePAT(accountID string, executingUserID string, targetUserId string, tokenID string) error - GetPAT(accountID string, executingUserID string, targetUserId string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(accountID string, executingUserID string, targetUserId string) ([]*PersonalAccessToken, error) + CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error + GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error)