Optimize Cache and IDP Management (#1147)

This pull request modifies the IdP and cache manager(s) to prevent the sending of app metadata
 to the upstream IDP on self-hosted instances. 
As a result, the IdP will now load all users from the IdP without filtering based on accountID.

We disable user invites as the administrator's own IDP system manages them.
This commit is contained in:
Bethuel Mmbaga 2023-10-03 17:40:28 +03:00 committed by GitHub
parent a952e7c72f
commit e26ec0b937
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 182 additions and 1670 deletions

View File

@ -988,6 +988,27 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return err return err
} }
// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
// update their AppMetadata with the AccountID.
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
for _, user := range unsetData {
accountID, err := am.Store.GetAccountByUser(user.ID)
if err == nil {
data := userData[accountID.Id]
if data == nil {
data = make([]*idp.UserData, 0, 1)
}
user.AppMetadata.WTAccountID = accountID.Id
userData[accountID.Id] = append(data, user)
}
}
}
delete(userData, idp.UnsetAccountID)
for accountID, users := range userData { for accountID, users := range userData {
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil { if err != nil {

View File

@ -210,47 +210,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}
var pendingInvite bool
if appMetadata.WTPendingInvite != nil {
pendingInvite = *appMetadata.WTPendingInvite
}
patchedUserReq := api.PatchedUserRequest{
Attributes: map[string]interface{}{
wtAccountID: appMetadata.WTAccountID,
wtPendingInvite: pendingInvite,
},
}
_, resp, err := am.apiClient.CoreApi.CoreUsersPartialUpdate(ctx, int32(userPk)).
PatchedUserRequest(patchedUserReq).
Execute()
if err != nil {
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user %s, statusCode %d", userID, resp.StatusCode)
}
return nil return nil
} }
@ -283,7 +243,10 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
} }
return parseAuthentikUser(*user) userData := parseAuthentikUser(*user)
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
@ -293,8 +256,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err return nil, err
} }
accountFilter := fmt.Sprintf("{%q:%q}", wtAccountID, accountID) userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Attributes(accountFilter).Execute()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -313,10 +275,9 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, user := range userList.Results { for _, user := range userList.Results {
userData, err := parseAuthentikUser(user) userData := parseAuthentikUser(user)
if err != nil { userData.AppMetadata.WTAccountID = accountID
return nil, err
}
users = append(users, userData) users = append(users, userData)
} }
@ -350,65 +311,16 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results { for _, user := range userList.Results {
userData, err := parseAuthentikUser(user) userData := parseAuthentikUser(user)
if err != nil { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// CreateUser creates a new user in authentik Idp and sends an invitation. // CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
ctx, err := am.authenticationContext() return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
groupID, err := am.getUserGroupByName("netbird")
if err != nil {
return nil, err
}
defaultBoolValue := true
createUserRequest := api.UserRequest{
Email: &email,
Name: name,
IsActive: &defaultBoolValue,
Groups: []string{groupID},
Username: email,
Attributes: map[string]interface{}{
wtAccountID: accountID,
wtPendingInvite: &defaultBoolValue,
},
}
user, resp, err := am.apiClient.CoreApi.CoreUsersCreate(ctx).UserRequest(createUserRequest).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
return parseAuthentikUser(*user)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -438,11 +350,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, user := range userList.Results { for _, user := range userList.Results {
userData, err := parseAuthentikUser(user) users = append(users, parseAuthentikUser(user))
if err != nil {
return nil, err
}
users = append(users, userData)
} }
return users, nil return users, nil
@ -501,64 +409,10 @@ func (am *AuthentikManager) authenticationContext() (context.Context, error) {
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
} }
// getUserGroupByName retrieves the user group for assigning new users. func parseAuthentikUser(user api.User) *UserData {
// If the group is not found, a new group with the specified name will be created.
func (am *AuthentikManager) getUserGroupByName(name string) (string, error) {
ctx, err := am.authenticationContext()
if err != nil {
return "", err
}
groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()
if groupList != nil {
if len(groupList.Results) > 0 {
return groupList.Results[0].Pk, nil
}
}
createGroupRequest := api.GroupRequest{Name: name}
group, resp, err := am.apiClient.CoreApi.CoreGroupsCreate(ctx).GroupRequest(createGroupRequest).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode)
}
return group.Pk, nil
}
func parseAuthentikUser(user api.User) (*UserData, error) {
var attributes struct {
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
}
helper := JsonParser{}
buf, err := helper.Marshal(user.Attributes)
if err != nil {
return nil, err
}
err = helper.Unmarshal(buf, &attributes)
if err != nil {
return nil, err
}
return &UserData{ return &UserData{
Email: *user.Email, Email: *user.Email,
Name: user.Name, Name: user.Name,
ID: strconv.FormatInt(int64(user.Pk), 10), ID: strconv.FormatInt(int64(user.Pk), 10),
AppMetadata: AppMetadata{ }
WTAccountID: attributes.AccountID,
WTPendingInvite: &attributes.PendingInvite,
},
}, nil
} }

View File

@ -1,7 +1,6 @@
package idp package idp
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -11,18 +10,12 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
const ( const profileFields = "id,displayName,mail,userPrincipalName"
// azure extension properties template
wtAccountIDTpl = "extension_%s_wt_account_id"
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
profileFields = "id,displayName,mail,userPrincipalName"
extensionFields = "id,name,targetObjects"
)
// AzureManager azure manager client instance. // AzureManager azure manager client instance.
type AzureManager struct { type AzureManager struct {
@ -58,21 +51,6 @@ type AzureCredentials struct {
// azureProfile represents an azure user profile. // azureProfile represents an azure user profile.
type azureProfile map[string]any type azureProfile map[string]any
// passwordProfile represent authentication method for,
// newly created user profile.
type passwordProfile struct {
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
Password string `json:"password"`
}
// azureExtension represent custom attribute,
// that can be added to user objects in Azure Active Directory (AD).
type azureExtension struct {
Name string `json:"name"`
DataType string `json:"dataType"`
TargetObjects []string `json:"targetObjects"`
}
// NewAzureManager creates a new instance of the AzureManager. // NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) { func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@ -115,7 +93,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
appMetrics: appMetrics, appMetrics: appMetrics,
} }
manager := &AzureManager{ return &AzureManager{
ObjectID: config.ObjectID, ObjectID: config.ObjectID,
ClientID: config.ClientID, ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint, GraphAPIEndpoint: config.GraphAPIEndpoint,
@ -123,14 +101,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
credentials: credentials, credentials: credentials,
helper: helper, helper: helper,
appMetrics: appMetrics, appMetrics: appMetrics,
} }, nil
err := manager.configureAppMetadata()
if err != nil {
return nil, err
}
return manager, nil
} }
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure. // jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
@ -236,44 +207,14 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in azure AD Idp. // CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID) return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
body, err := am.post("users", payload)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
profile[wtAccountIDField] = accountID
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
profile[wtPendingInviteField] = true
return profile.userData(am.ClientID), nil
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
body, err := am.get("users/"+userID, q) body, err := am.get("users/"+userID, q)
if err != nil { if err != nil {
@ -290,18 +231,17 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
return nil, err return nil, err
} }
return profile.userData(am.ClientID), nil userData := profile.userData()
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
body, err := am.get("users/"+email, q) body, err := am.get("users/"+email, q)
if err != nil { if err != nil {
@ -319,20 +259,15 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
users := make([]*UserData, 0) users := make([]*UserData, 0)
users = append(users, profile.userData(am.ClientID)) users = append(users, profile.userData())
return users, nil return users, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
body, err := am.get("users", q) body, err := am.get("users", q)
if err != nil { if err != nil {
@ -351,7 +286,10 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, profile := range profiles.Value { for _, profile := range profiles.Value {
users = append(users, profile.userData(am.ClientID)) userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
} }
return users, nil return users, nil
@ -360,12 +298,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
body, err := am.get("users", q) body, err := am.get("users", q)
if err != nil { if err != nil {
@ -384,67 +318,15 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value { for _, profile := range profiles.Value {
userData := profile.userData(am.ClientID) userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// UpdateUserAppMetadata updates user app metadata based on userID. // UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
data, err := am.helper.Marshal(map[string]any{
wtAccountIDField: appMetadata.WTAccountID,
wtPendingInviteField: appMetadata.WTPendingInvite,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating idp metadata for user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
return nil return nil
} }
@ -454,7 +336,7 @@ func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Azure // DeleteUser from Azure.
func (am *AzureManager) DeleteUser(userID string) error { func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
if err != nil { if err != nil {
@ -491,81 +373,6 @@ func (am *AzureManager) DeleteUser(userID string) error {
return nil return nil
} }
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{}
q.Add("$select", extensionFields)
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.get(resource, q)
if err != nil {
return nil, err
}
var extensions struct{ Value []azureExtension }
err = am.helper.Unmarshal(body, &extensions)
if err != nil {
return nil, err
}
return extensions.Value, nil
}
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
extension := azureExtension{
Name: name,
DataType: "string",
TargetObjects: []string{"User"},
}
payload, err := am.helper.Marshal(extension)
if err != nil {
return nil, err
}
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.post(resource, string(payload))
if err != nil {
return nil, err
}
var userExtension azureExtension
err = am.helper.Unmarshal(body, &userExtension)
if err != nil {
return nil, err
}
return &userExtension, nil
}
// configureAppMetadata sets up app metadata extensions if they do not exists.
func (am *AzureManager) configureAppMetadata() error {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
extensions, err := am.getUserExtensions()
if err != nil {
return err
}
// If the wt_account_id extension does not already exist, create it.
if !hasExtension(extensions, wtAccountIDField) {
_, err = am.createUserExtension(wtAccountID)
if err != nil {
return err
}
}
// If the wt_pending_invite extension does not already exist, create it.
if !hasExtension(extensions, wtPendingInviteField) {
_, err = am.createUserExtension(wtPendingInvite)
if err != nil {
return err
}
}
return nil
}
// get perform Get requests. // get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
@ -602,44 +409,8 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
return io.ReadAll(resp.Body) return io.ReadAll(resp.Body)
} }
// post perform Post requests.
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile. // userData construct user data from keycloak profile.
func (ap azureProfile) userData(clientID string) *UserData { func (ap azureProfile) userData() *UserData {
id, ok := ap["id"].(string) id, ok := ap["id"].(string)
if !ok { if !ok {
id = "" id = ""
@ -655,66 +426,9 @@ func (ap azureProfile) userData(clientID string) *UserData {
name = "" name = ""
} }
accountIDField := extensionName(wtAccountIDTpl, clientID)
accountID, ok := ap[accountIDField].(string)
if !ok {
accountID = ""
}
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
pendingInvite, ok := ap[pendingInviteField].(bool)
if !ok {
pendingInvite = false
}
return &UserData{ return &UserData{
Email: email, Email: email,
Name: name, Name: name,
ID: id, ID: id,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
} }
} }
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
req := &azureProfile{
"accountEnabled": true,
"displayName": name,
"mailNickName": strings.Join(strings.Split(name, " "), ""),
"userPrincipalName": email,
"passwordProfile": passwordProfile{
ForceChangePasswordNextSignIn: true,
Password: GeneratePassword(8, 1, 1, 1),
},
wtAccountIDField: accountID,
wtPendingInviteField: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func extensionName(extensionTpl, clientID string) string {
clientID = strings.ReplaceAll(clientID, "-", "")
return fmt.Sprintf(extensionTpl, clientID)
}
// hasExtension checks whether a given extension by name,
// exists in an list of extensions.
func hasExtension(extensions []azureExtension, name string) bool {
for _, ext := range extensions {
if ext.Name == name {
return true
}
}
return false
}

View File

@ -8,15 +8,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type mockAzureCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestAzureJwtStillValid(t *testing.T) { func TestAzureJwtStillValid(t *testing.T) {
type jwtStillValidTest struct { type jwtStillValidTest struct {
name string name string
@ -124,206 +115,63 @@ func TestAzureAuthenticate(t *testing.T) {
} }
} }
func TestAzureUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &AzureManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestAzureProfile(t *testing.T) { func TestAzureProfile(t *testing.T) {
type azureProfileTest struct { type azureProfileTest struct {
name string name string
clientID string
invite bool invite bool
inputProfile azureProfile inputProfile azureProfile
expectedUserData UserData expectedUserData UserData
} }
azureProfileTestCase1 := azureProfileTest{ azureProfileTestCase1 := azureProfileTest{
name: "Good Request", name: "Good Request",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", invite: false,
invite: false,
inputProfile: azureProfile{ inputProfile: azureProfile{
"id": "test1", "id": "test1",
"displayName": "John Doe", "displayName": "John Doe",
"userPrincipalName": "test1@test.com", "userPrincipalName": "test1@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
}, },
expectedUserData: UserData{ expectedUserData: UserData{
Email: "test1@test.com", Email: "test1@test.com",
Name: "John Doe", Name: "John Doe",
ID: "test1", ID: "test1",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
}, },
} }
azureProfileTestCase2 := azureProfileTest{ azureProfileTestCase2 := azureProfileTest{
name: "Missing User ID", name: "Missing User ID",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", invite: true,
invite: true,
inputProfile: azureProfile{ inputProfile: azureProfile{
"displayName": "John Doe", "displayName": "John Doe",
"userPrincipalName": "test2@test.com", "userPrincipalName": "test2@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
}, },
expectedUserData: UserData{ expectedUserData: UserData{
Email: "test2@test.com", Email: "test2@test.com",
Name: "John Doe", Name: "John Doe",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
}, },
} }
azureProfileTestCase3 := azureProfileTest{ azureProfileTestCase3 := azureProfileTest{
name: "Missing User Name", name: "Missing User Name",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c", invite: false,
invite: false,
inputProfile: azureProfile{ inputProfile: azureProfile{
"id": "test3", "id": "test3",
"userPrincipalName": "test3@test.com", "userPrincipalName": "test3@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
}, },
expectedUserData: UserData{ expectedUserData: UserData{
ID: "test3", ID: "test3",
Email: "test3@test.com", Email: "test3@test.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
}, },
} }
azureProfileTestCase4 := azureProfileTest{ for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
name: "Missing Extension Fields",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test4",
"displayName": "John Doe",
"userPrincipalName": "test4@test.com",
},
expectedUserData: UserData{
ID: "test4",
Name: "John Doe",
Email: "test4@test.com",
AppMetadata: AppMetadata{},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData(testCase.clientID) userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match") assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match") assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match") assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
}) })
} }
} }

View File

@ -5,15 +5,14 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
"google.golang.org/api/option" "google.golang.org/api/option"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
// GoogleWorkspaceManager Google Workspace manager client instance. // GoogleWorkspaceManager Google Workspace manager client instance.
@ -73,17 +72,13 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
} }
service, err := admin.NewService(context.Background(), service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope), option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
option.WithCredentials(adminCredentials), option.WithCredentials(adminCredentials),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = configureAppMetadataSchema(service, config.CustomerID); err != nil {
return nil, err
}
return &GoogleWorkspaceManager{ return &GoogleWorkspaceManager{
usersService: service.Users, usersService: service.Users,
CustomerID: config.CustomerID, CustomerID: config.CustomerID,
@ -95,27 +90,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
metadata, err := gm.helper.Marshal(appMetadata)
if err != nil {
return err
}
user := &admin.User{
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": metadata,
},
}
_, err = gm.usersService.Update(userID, user).Do()
if err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil return nil
} }
@ -130,23 +105,23 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
gm.appMetrics.IDPMetrics().CountGetUserDataByID() gm.appMetrics.IDPMetrics().CountGetUserDataByID()
} }
return parseGoogleWorkspaceUser(user) userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID) usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do()
if err != nil { if err != nil {
return nil, err return nil, err
} }
usersData := make([]*UserData, 0) usersData := make([]*UserData, 0)
for _, user := range usersList.Users { for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user) userData := parseGoogleWorkspaceUser(user)
if err != nil { userData.AppMetadata.WTAccountID = accountID
return nil, err
}
usersData = append(usersData, userData) usersData = append(usersData, userData)
} }
@ -168,61 +143,16 @@ func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, erro
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users { for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user) userData := parseGoogleWorkspaceUser(user)
if err != nil { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// CreateUser creates a new user in Google Workspace and sends an invitation. // CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
invite := true return nil, fmt.Errorf("method CreateUser not implemented")
metadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
username := &admin.UserName{}
fields := strings.Fields(name)
if n := len(fields); n > 0 {
username.GivenName = strings.Join(fields[:n-1], " ")
username.FamilyName = fields[n-1]
}
payload, err := gm.helper.Marshal(metadata)
if err != nil {
return nil, err
}
user := &admin.User{
Name: username,
PrimaryEmail: email,
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": payload,
},
Password: GeneratePassword(8, 1, 1, 1),
}
user, err = gm.usersService.Insert(user).Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountCreateUser()
}
return parseGoogleWorkspaceUser(user)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -237,13 +167,8 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
gm.appMetrics.IDPMetrics().CountGetUserByEmail() gm.appMetrics.IDPMetrics().CountGetUserByEmail()
} }
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
users := make([]*UserData, 0) users := make([]*UserData, 0)
users = append(users, userData) users = append(users, parseGoogleWorkspaceUser(user))
return users, nil return users, nil
} }
@ -281,8 +206,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
creds, err := google.CredentialsFromJSON( creds, err := google.CredentialsFromJSON(
context.Background(), context.Background(),
decodeKey, decodeKey,
admin.AdminDirectoryUserschemaScope, admin.AdminDirectoryUserReadonlyScope,
admin.AdminDirectoryUserScope,
) )
if err == nil { if err == nil {
// No need to fallback to the default Google credentials path // No need to fallback to the default Google credentials path
@ -294,8 +218,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
creds, err = google.FindDefaultCredentials( creds, err = google.FindDefaultCredentials(
context.Background(), context.Background(),
admin.AdminDirectoryUserschemaScope, admin.AdminDirectoryUserReadonlyScope,
admin.AdminDirectoryUserScope,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -304,62 +227,11 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
return creds, nil return creds, nil
} }
// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace.
func configureAppMetadataSchema(service *admin.Service, customerID string) error {
schemaList, err := service.Schemas.List(customerID).Do()
if err != nil {
return err
}
// checks if app_metadata schema is already created
for _, schema := range schemaList.Schemas {
if schema.SchemaName == "app_metadata" {
return nil
}
}
// create new app_metadata schema
appMetadataSchema := &admin.Schema{
SchemaName: "app_metadata",
Fields: []*admin.SchemaFieldSpec{
{
FieldName: "wt_account_id",
FieldType: "STRING",
MultiValued: false,
},
{
FieldName: "wt_pending_invite",
FieldType: "BOOL",
MultiValued: false,
},
},
}
_, err = service.Schemas.Insert(customerID, appMetadataSchema).Do()
if err != nil {
return err
}
return nil
}
// parseGoogleWorkspaceUser parse google user to UserData. // parseGoogleWorkspaceUser parse google user to UserData.
func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) { func parseGoogleWorkspaceUser(user *admin.User) *UserData {
var appMetadata AppMetadata
// Get app metadata from custom schemas
if user.CustomSchemas != nil {
rawMessage := user.CustomSchemas["app_metadata"]
helper := JsonParser{}
if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil {
return nil, err
}
}
return &UserData{ return &UserData{
ID: user.Id, ID: user.Id,
Email: user.PrimaryEmail, Email: user.PrimaryEmail,
Name: user.Name.FullName, Name: user.Name.FullName,
AppMetadata: appMetadata, }
}, nil
} }

View File

@ -9,6 +9,11 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
const (
// UnsetAccountID is a special key to map users without an account ID
UnsetAccountID = "unset"
)
// Manager idp manager interface // Manager idp manager interface
type Manager interface { type Manager interface {
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error

View File

@ -1,12 +1,10 @@
package idp package idp
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -18,11 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
const (
wtAccountID = "wt_account_id"
wtPendingInvite = "wt_pending_invite"
)
// KeycloakManager keycloak manager client instance. // KeycloakManager keycloak manager client instance.
type KeycloakManager struct { type KeycloakManager struct {
adminEndpoint string adminEndpoint string
@ -51,28 +44,10 @@ type KeycloakCredentials struct {
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
} }
// keycloakUserCredential describe the authentication method for,
// newly created user profile.
type keycloakUserCredential struct {
Type string `json:"type"`
Value string `json:"value"`
Temporary bool `json:"temporary"`
}
// keycloakUserAttributes holds additional user data fields. // keycloakUserAttributes holds additional user data fields.
type keycloakUserAttributes map[string][]string type keycloakUserAttributes map[string][]string
// createUserRequest is a user create request. // keycloakProfile represents a keycloak user profile response.
type keycloakCreateUserRequest struct {
Email string `json:"email"`
Username string `json:"username"`
Enabled bool `json:"enabled"`
EmailVerified bool `json:"emailVerified"`
Credentials []keycloakUserCredential `json:"credentials"`
Attributes keycloakUserAttributes `json:"attributes"`
}
// keycloakProfile represents an keycloak user profile response.
type keycloakProfile struct { type keycloakProfile struct {
ID string `json:"id"` ID string `json:"id"`
CreatedTimestamp int64 `json:"createdTimestamp"` CreatedTimestamp int64 `json:"createdTimestamp"`
@ -230,62 +205,8 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in keycloak Idp and sends an invite. // CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
jwtToken, err := km.credentials.Authenticate() return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata)
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/users", km.adminEndpoint)
payload := strings.NewReader(payloadString)
req, err := http.NewRequest(http.MethodPost, reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
locationHeader := resp.Header.Get("location")
userID, err := extractUserIDFromLocationHeader(locationHeader)
if err != nil {
return nil, err
}
return km.GetUserDataByID(userID, appMetadata)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -319,7 +240,7 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get("users/"+userID, nil) body, err := km.get("users/"+userID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -338,12 +259,9 @@ func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadat
return profile.userData(), nil return profile.userData(), nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
q := url.Values{} profiles, err := km.fetchAllUserProfiles()
q.Add("q", wtAccountID+":"+accountID)
body, err := km.get("users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -352,15 +270,12 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
km.appMetrics.IDPMetrics().CountGetAccount() km.appMetrics.IDPMetrics().CountGetAccount()
} }
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, profile := range profiles { for _, profile := range profiles {
users = append(users, profile.userData()) userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
} }
return users, nil return users, nil
@ -369,15 +284,7 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
totalUsers, err := km.totalUsersCount() profiles, err := km.fetchAllUserProfiles()
if err != nil {
return nil, err
}
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -386,78 +293,17 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
km.appMetrics.IDPMetrics().CountGetAllAccounts() km.appMetrics.IDPMetrics().CountGetAllAccounts()
} }
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles { for _, profile := range profiles {
userData := profile.userData() userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return err
}
attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID)
if appMetadata.WTPendingInvite != nil {
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
} else {
attrs.Set(wtPendingInvite, "false")
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID)
data, err := km.helper.Marshal(map[string]any{
"attributes": attrs,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
req, err := http.NewRequest(http.MethodPut, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating IdP metadata for user %s", userID)
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
return nil return nil
} }
@ -467,7 +313,7 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Keycloack // DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(userID string) error { func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate() jwtToken, err := km.credentials.Authenticate()
if err != nil { if err != nil {
@ -475,7 +321,6 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
} }
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID)) reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil) req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil { if err != nil {
return err return err
@ -508,32 +353,27 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
return nil return nil
} }
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
attrs := keycloakUserAttributes{} totalUsers, err := km.totalUsersCount()
attrs.Set(wtAccountID, appMetadata.WTAccountID)
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
req := &keycloakCreateUserRequest{
Email: email,
Username: name,
Enabled: true,
EmailVerified: true,
Credentials: []keycloakUserCredential{
{
Type: "password",
Value: GeneratePassword(8, 1, 1, 1),
Temporary: false,
},
},
Attributes: attrs,
}
str, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
return string(str), nil q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
if err != nil {
return nil, err
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
return profiles, nil
} }
// get perform Get requests. // get perform Get requests.
@ -588,53 +428,11 @@ func (km *KeycloakManager) totalUsersCount() (*int, error) {
return &count, nil return &count, nil
} }
// extractUserIDFromLocationHeader extracts the user ID from the location,
// header once the user is created successfully
func extractUserIDFromLocationHeader(locationHeader string) (string, error) {
userURL, err := url.Parse(locationHeader)
if err != nil {
return "", err
}
return path.Base(userURL.Path), nil
}
// userData construct user data from keycloak profile. // userData construct user data from keycloak profile.
func (kp keycloakProfile) userData() *UserData { func (kp keycloakProfile) userData() *UserData {
accountID := kp.Attributes.Get(wtAccountID)
pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite))
if err != nil {
pendingInvite = false
}
return &UserData{ return &UserData{
Email: kp.Email, Email: kp.Email,
Name: kp.Username, Name: kp.Username,
ID: kp.ID, ID: kp.ID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
} }
} }
// Set sets the key to value. It replaces any existing
// values.
func (ka keycloakUserAttributes) Set(key, value string) {
ka[key] = []string{value}
}
// Get returns the first value associated with the given key.
// If there are no values associated with the key, Get returns
// the empty string.
func (ka keycloakUserAttributes) Get(key string) string {
if ka == nil {
return ""
}
values := ka[key]
if len(values) == 0 {
return ""
}
return values[0]
}

View File

@ -84,15 +84,6 @@ func TestNewKeycloakManager(t *testing.T) {
} }
} }
type mockKeycloakCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestKeycloakRequestJWTToken(t *testing.T) { func TestKeycloakRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct { type requestJWTTokenTest struct {
@ -316,108 +307,3 @@ func TestKeycloakAuthenticate(t *testing.T) {
}) })
} }
} }
func TestKeycloakUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &KeycloakManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}

View File

@ -8,9 +8,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/okta/okta-sdk-golang/v2/okta" "github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
// OktaManager okta manager client instance. // OktaManager okta manager client instance.
@ -76,11 +76,6 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
return nil, err return nil, err
} }
err = updateUserProfileSchema(client)
if err != nil {
return nil, err
}
credentials := &OktaCredentials{ credentials := &OktaCredentials{
clientConfig: config, clientConfig: config,
httpClient: httpClient, httpClient: httpClient,
@ -103,49 +98,8 @@ func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in okta Idp and sends an invitation. // CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
var ( return nil, fmt.Errorf("method CreateUser not implemented")
sendEmail = true
activate = true
userProfile = okta.UserProfile{
"email": email,
"login": email,
wtAccountID: accountID,
wtPendingInvite: true,
}
)
fields := strings.Fields(name)
if n := len(fields); n > 0 {
userProfile["firstName"] = strings.Join(fields[:n-1], " ")
userProfile["lastName"] = fields[n-1]
}
user, resp, err := om.client.User.CreateUser(context.Background(),
okta.CreateUserRequest{
Profile: &userProfile,
},
&query.Params{
Activate: &activate,
SendEmail: &sendEmail,
},
)
if err != nil {
return nil, err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountCreateUser()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
return parseOktaUser(user)
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
@ -166,7 +120,13 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
} }
return parseOktaUser(user) userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -200,8 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
search := fmt.Sprintf("profile.wt_account_id eq %q", accountID) users, resp, err := om.client.User.ListUsers(context.Background(), nil)
users, resp, err := om.client.User.ListUsers(context.Background(), &query.Params{Search: search})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -223,6 +182,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
userData.AppMetadata.WTAccountID = accountID
list = append(list, userData) list = append(list, userData)
} }
@ -256,13 +216,7 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, err return nil, err
} }
accountID := userData.AppMetadata.WTAccountID indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
@ -270,46 +224,6 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
}
profile := *user.Profile
if appMetadata.WTPendingInvite != nil {
profile[wtPendingInvite] = *appMetadata.WTPendingInvite
}
if appMetadata.WTAccountID != "" {
profile[wtAccountID] = appMetadata.WTAccountID
}
user.Profile = &profile
_, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil)
if err != nil {
fmt.Println(err.Error())
return err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
}
return nil return nil
} }
@ -341,60 +255,12 @@ func (om *OktaManager) DeleteUser(userID string) error {
return nil return nil
} }
// updateUserProfileSchema updates the Okta user schema to include custom fields,
// wt_account_id and wt_pending_invite.
func updateUserProfileSchema(client *okta.Client) error {
// Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird
userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}}
_, resp, err := client.UserSchema.UpdateUserProfile(
context.Background(),
"default",
okta.UserSchema{
Definitions: &okta.UserSchemaDefinitions{
Custom: &okta.UserSchemaPublic{
Id: "#custom",
Type: "object",
Properties: map[string]*okta.UserSchemaAttribute{
wtAccountID: {
MaxLength: 100,
MinLength: 1,
Required: new(bool),
Scope: "NONE",
Title: "Wt Account Id",
Type: "string",
Permissions: userPermissions,
},
wtPendingInvite: {
Required: new(bool),
Scope: "NONE",
Title: "Wt Pending Invite",
Type: "boolean",
Permissions: userPermissions,
},
},
},
},
})
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unable to update user profile schema, statusCode %d", resp.StatusCode)
}
return nil
}
// parseOktaUserToUserData parse okta user to UserData. // parseOktaUserToUserData parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) { func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct { var oktaUser struct {
Email string `json:"email"` Email string `json:"email"`
FirstName string `json:"firstName"` FirstName string `json:"firstName"`
LastName string `json:"lastName"` LastName string `json:"lastName"`
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
} }
if user == nil { if user == nil {
@ -418,9 +284,5 @@ func parseOktaUser(user *okta.User) (*UserData, error) {
Email: oktaUser.Email, Email: oktaUser.Email,
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "), Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
ID: user.Id, ID: user.Id,
AppMetadata: AppMetadata{
WTAccountID: oktaUser.AccountID,
WTPendingInvite: &oktaUser.PendingInvite,
},
}, nil }, nil
} }

View File

@ -1,31 +1,28 @@
package idp package idp
import ( import (
"testing"
"github.com/okta/okta-sdk-golang/v2/okta" "github.com/okta/okta-sdk-golang/v2/okta"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing"
) )
func TestParseOktaUser(t *testing.T) { func TestParseOktaUser(t *testing.T) {
type parseOktaUserTest struct { type parseOktaUserTest struct {
name string name string
invite bool
inputProfile *okta.User inputProfile *okta.User
expectedUserData *UserData expectedUserData *UserData
assertErrFunc assert.ErrorAssertionFunc assertErrFunc assert.ErrorAssertionFunc
} }
parseOktaTestCase1 := parseOktaUserTest{ parseOktaTestCase1 := parseOktaUserTest{
name: "Good Request", name: "Good Request",
invite: true,
inputProfile: &okta.User{ inputProfile: &okta.User{
Id: "123", Id: "123",
Profile: &okta.UserProfile{ Profile: &okta.UserProfile{
"email": "test@example.com", "email": "test@example.com",
"firstName": "John", "firstName": "John",
"lastName": "Doe", "lastName": "Doe",
"wt_account_id": "456",
"wt_pending_invite": true,
}, },
}, },
expectedUserData: &UserData{ expectedUserData: &UserData{
@ -41,36 +38,17 @@ func TestParseOktaUser(t *testing.T) {
parseOktaTestCase2 := parseOktaUserTest{ parseOktaTestCase2 := parseOktaUserTest{
name: "Invalid okta user", name: "Invalid okta user",
invite: true,
inputProfile: nil, inputProfile: nil,
expectedUserData: nil, expectedUserData: nil,
assertErrFunc: assert.Error, assertErrFunc: assert.Error,
} }
parseOktaTestCase3 := parseOktaUserTest{ for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
name: "Invalid pending invite type",
invite: false,
inputProfile: &okta.User{
Id: "123",
Profile: &okta.UserProfile{
"email": "test@example.com",
"firstName": "John",
"lastName": "Doe",
"wt_account_id": "456",
"wt_pending_invite": "true",
},
},
expectedUserData: nil,
assertErrFunc: assert.Error,
}
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2, parseOktaTestCase3} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
userData, err := parseOktaUser(testCase.inputProfile) userData, err := parseOktaUser(testCase.inputProfile)
testCase.assertErrFunc(t, err, testCase.assertErrFunc) testCase.assertErrFunc(t, err, testCase.assertErrFunc)
if err == nil { if err == nil {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match") assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
} }
}) })
@ -83,13 +61,5 @@ func userDataEqual(a, b *UserData) bool {
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID { if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
return false return false
} }
if a.AppMetadata.WTAccountID != b.AppMetadata.WTAccountID {
return false
}
if a.AppMetadata.WTPendingInvite != nil && b.AppMetadata.WTPendingInvite != nil &&
*a.AppMetadata.WTPendingInvite != *b.AppMetadata.WTPendingInvite {
return false
}
return true return true
} }

View File

@ -1,13 +1,10 @@
package idp package idp
import ( import (
"encoding/base64"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -68,12 +65,6 @@ type zitadelUser struct {
type zitadelAttributes map[string][]map[string]any type zitadelAttributes map[string][]map[string]any
// zitadelMetadata holds additional user data.
type zitadelMetadata struct {
Key string `json:"key"`
Value string `json:"value"`
}
// zitadelProfile represents an zitadel user profile response. // zitadelProfile represents an zitadel user profile response.
type zitadelProfile struct { type zitadelProfile struct {
ID string `json:"id"` ID string `json:"id"`
@ -82,7 +73,6 @@ type zitadelProfile struct {
PreferredLoginName string `json:"preferredLoginName"` PreferredLoginName string `json:"preferredLoginName"`
LoginNames []string `json:"loginNames"` LoginNames []string `json:"loginNames"`
Human *zitadelUser `json:"human"` Human *zitadelUser `json:"human"`
Metadata []zitadelMetadata
} }
// NewZitadelManager creates a new instance of the ZitadelManager. // NewZitadelManager creates a new instance of the ZitadelManager.
@ -235,42 +225,8 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in zitadel Idp and sends an invite. // CreateUser creates a new user in zitadel Idp and sends an invite.
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) {
payload, err := buildZitadelCreateUserRequestPayload(email, name) return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
body, err := zm.post("users/human/_import", payload)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountCreateUser()
}
var result struct {
UserID string `json:"userId"`
}
err = zm.helper.Unmarshal(body, &result)
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
// Add metadata to new user
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
if err != nil {
return nil, err
}
return zm.GetUserDataByID(result.UserID, appMetadata)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -308,12 +264,6 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, profile := range profiles.Result { for _, profile := range profiles.Result {
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
users = append(users, profile.userData()) users = append(users, profile.userData())
} }
@ -337,18 +287,15 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
return nil, err return nil, err
} }
metadata, err := zm.getUserMetadata(userID) userData := profile.User.userData()
if err != nil { userData.AppMetadata = appMetadata
return nil, err
}
profile.User.Metadata = metadata
return profile.User.userData(), nil return userData, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
accounts, err := zm.GetAllAccounts() body, err := zm.post("users/_search", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -357,7 +304,21 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
zm.appMetrics.IDPMetrics().CountGetAccount() zm.appMetrics.IDPMetrics().CountGetAccount()
} }
return accounts[accountID], nil var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
} }
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
@ -380,22 +341,8 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Result { for _, profile := range profiles.Result {
// fetch user metadata
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
userData := profile.userData() userData := profile.userData()
accountID := userData.AppMetadata.WTAccountID indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
@ -403,42 +350,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded. // Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
if appMetadata.WTPendingInvite == nil {
appMetadata.WTPendingInvite = new(bool)
}
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
metadata := zitadelAttributes{
"metadata": {
{
"key": wtAccountID,
"value": wtAccountIDValue,
},
{
"key": wtPendingInvite,
"value": wtPendingInviteValue,
},
},
}
payload, err := zm.helper.Marshal(metadata)
if err != nil {
return err
}
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
_, err = zm.post(resource, string(payload))
if err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil return nil
} }
@ -460,24 +372,6 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
} }
return nil return nil
}
// getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
body, err := zm.post(resource, "")
if err != nil {
return nil, err
}
var metadata struct{ Result []zitadelMetadata }
err = zm.helper.Unmarshal(body, &metadata)
if err != nil {
return nil, err
}
return metadata.Result, nil
} }
// post perform Post requests. // post perform Post requests.
@ -517,38 +411,7 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
} }
// delete perform Delete requests. // delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error { func (zm *ZitadelManager) delete(_ string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
}
return nil return nil
} }
@ -588,38 +451,13 @@ func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
return io.ReadAll(resp.Body) return io.ReadAll(resp.Body)
} }
// value returns string represented by the base64 string value.
func (zm zitadelMetadata) value() string {
value, err := base64.StdEncoding.DecodeString(zm.Value)
if err != nil {
return ""
}
return string(value)
}
// userData construct user data from zitadel profile. // userData construct user data from zitadel profile.
func (zp zitadelProfile) userData() *UserData { func (zp zitadelProfile) userData() *UserData {
var ( var (
email string email string
name string name string
wtAccountIDValue string
wtPendingInviteValue bool
) )
for _, metadata := range zp.Metadata {
if metadata.Key == wtAccountID {
wtAccountIDValue = metadata.value()
}
if metadata.Key == wtPendingInvite {
value, err := strconv.ParseBool(metadata.value())
if err == nil {
wtPendingInviteValue = value
}
}
}
// Obtain the email for the human account and the login name, // Obtain the email for the human account and the login name,
// for the machine account. // for the machine account.
if zp.Human != nil { if zp.Human != nil {
@ -636,39 +474,5 @@ func (zp zitadelProfile) userData() *UserData {
Email: email, Email: email,
Name: name, Name: name,
ID: zp.ID, ID: zp.ID,
AppMetadata: AppMetadata{
WTAccountID: wtAccountIDValue,
WTPendingInvite: &wtPendingInviteValue,
},
} }
} }
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
var firstName, lastName string
words := strings.Fields(name)
if n := len(words); n > 0 {
firstName = strings.Join(words[:n-1], " ")
lastName = words[n-1]
}
req := &zitadelUser{
UserName: name,
Profile: zitadelUserInfo{
FirstName: strings.TrimSpace(firstName),
LastName: strings.TrimSpace(lastName),
DisplayName: name,
},
Email: zitadelEmail{
Email: email,
IsEmailVerified: false,
},
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}

View File

@ -7,9 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
func TestNewZitadelManager(t *testing.T) { func TestNewZitadelManager(t *testing.T) {
@ -63,15 +64,6 @@ func TestNewZitadelManager(t *testing.T) {
} }
} }
type mockZitadelCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestZitadelRequestJWTToken(t *testing.T) { func TestZitadelRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct { type requestJWTTokenTest struct {
@ -296,98 +288,6 @@ func TestZitadelAuthenticate(t *testing.T) {
} }
} }
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &ZitadelManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestZitadelProfile(t *testing.T) { func TestZitadelProfile(t *testing.T) {
type azureProfileTest struct { type azureProfileTest struct {
name string name string
@ -418,16 +318,6 @@ func TestZitadelProfile(t *testing.T) {
IsEmailVerified: true, IsEmailVerified: true,
}, },
}, },
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "ZmFsc2U=",
},
},
}, },
expectedUserData: UserData{ expectedUserData: UserData{
ID: "test1", ID: "test1",
@ -451,16 +341,6 @@ func TestZitadelProfile(t *testing.T) {
"machine", "machine",
}, },
Human: nil, Human: nil,
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "dHJ1ZQ==",
},
},
}, },
expectedUserData: UserData{ expectedUserData: UserData{
ID: "test2", ID: "test2",
@ -480,8 +360,6 @@ func TestZitadelProfile(t *testing.T) {
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match") assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match") assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match") assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
}) })
} }
} }