mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-11 16:38:27 +01:00
Optimize Cache and IDP Management (#1147)
This pull request modifies the IdP and cache manager(s) to prevent the sending of app metadata to the upstream IDP on self-hosted instances. As a result, the IdP will now load all users from the IdP without filtering based on accountID. We disable user invites as the administrator's own IDP system manages them.
This commit is contained in:
parent
a952e7c72f
commit
e26ec0b937
@ -988,6 +988,27 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the Identity Provider does not support writing AppMetadata,
|
||||
// in cases like this, we expect it to return all users in an "unset" field.
|
||||
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
|
||||
// update their AppMetadata with the AccountID.
|
||||
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
|
||||
for _, user := range unsetData {
|
||||
accountID, err := am.Store.GetAccountByUser(user.ID)
|
||||
if err == nil {
|
||||
data := userData[accountID.Id]
|
||||
if data == nil {
|
||||
data = make([]*idp.UserData, 0, 1)
|
||||
}
|
||||
|
||||
user.AppMetadata.WTAccountID = accountID.Id
|
||||
|
||||
userData[accountID.Id] = append(data, user)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(userData, idp.UnsetAccountID)
|
||||
|
||||
for accountID, users := range userData {
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
if err != nil {
|
||||
|
@ -210,47 +210,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
ctx, err := am.authenticationContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pendingInvite bool
|
||||
if appMetadata.WTPendingInvite != nil {
|
||||
pendingInvite = *appMetadata.WTPendingInvite
|
||||
}
|
||||
|
||||
patchedUserReq := api.PatchedUserRequest{
|
||||
Attributes: map[string]interface{}{
|
||||
wtAccountID: appMetadata.WTAccountID,
|
||||
wtPendingInvite: pendingInvite,
|
||||
},
|
||||
}
|
||||
_, resp, err := am.apiClient.CoreApi.CoreUsersPartialUpdate(ctx, int32(userPk)).
|
||||
PatchedUserRequest(patchedUserReq).
|
||||
Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to update user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -283,7 +243,10 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
return parseAuthentikUser(*user)
|
||||
userData := parseAuthentikUser(*user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
@ -293,8 +256,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountFilter := fmt.Sprintf("{%q:%q}", wtAccountID, accountID)
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Attributes(accountFilter).Execute()
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -313,10 +275,9 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
userData, err := parseAuthentikUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData := parseAuthentikUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
@ -350,65 +311,16 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range userList.Results {
|
||||
userData, err := parseAuthentikUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
userData := parseAuthentikUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
||||
func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupID, err := am.getUserGroupByName("netbird")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defaultBoolValue := true
|
||||
createUserRequest := api.UserRequest{
|
||||
Email: &email,
|
||||
Name: name,
|
||||
IsActive: &defaultBoolValue,
|
||||
Groups: []string{groupID},
|
||||
Username: email,
|
||||
Attributes: map[string]interface{}{
|
||||
wtAccountID: accountID,
|
||||
wtPendingInvite: &defaultBoolValue,
|
||||
},
|
||||
}
|
||||
user, resp, err := am.apiClient.CoreApi.CoreUsersCreate(ctx).UserRequest(createUserRequest).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return parseAuthentikUser(*user)
|
||||
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@ -438,11 +350,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
userData, err := parseAuthentikUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, userData)
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
return users, nil
|
||||
@ -501,64 +409,10 @@ func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
|
||||
}
|
||||
|
||||
// getUserGroupByName retrieves the user group for assigning new users.
|
||||
// If the group is not found, a new group with the specified name will be created.
|
||||
func (am *AuthentikManager) getUserGroupByName(name string) (string, error) {
|
||||
ctx, err := am.authenticationContext()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if groupList != nil {
|
||||
if len(groupList.Results) > 0 {
|
||||
return groupList.Results[0].Pk, nil
|
||||
}
|
||||
}
|
||||
|
||||
createGroupRequest := api.GroupRequest{Name: name}
|
||||
group, resp, err := am.apiClient.CoreApi.CoreGroupsCreate(ctx).GroupRequest(createGroupRequest).Execute()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return group.Pk, nil
|
||||
}
|
||||
|
||||
func parseAuthentikUser(user api.User) (*UserData, error) {
|
||||
var attributes struct {
|
||||
AccountID string `json:"wt_account_id"`
|
||||
PendingInvite bool `json:"wt_pending_invite"`
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
buf, err := helper.Marshal(user.Attributes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = helper.Unmarshal(buf, &attributes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func parseAuthentikUser(user api.User) *UserData {
|
||||
return &UserData{
|
||||
Email: *user.Email,
|
||||
Name: user.Name,
|
||||
ID: strconv.FormatInt(int64(user.Pk), 10),
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: attributes.AccountID,
|
||||
WTPendingInvite: &attributes.PendingInvite,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -11,18 +10,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
// azure extension properties template
|
||||
wtAccountIDTpl = "extension_%s_wt_account_id"
|
||||
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
|
||||
|
||||
profileFields = "id,displayName,mail,userPrincipalName"
|
||||
extensionFields = "id,name,targetObjects"
|
||||
)
|
||||
const profileFields = "id,displayName,mail,userPrincipalName"
|
||||
|
||||
// AzureManager azure manager client instance.
|
||||
type AzureManager struct {
|
||||
@ -58,21 +51,6 @@ type AzureCredentials struct {
|
||||
// azureProfile represents an azure user profile.
|
||||
type azureProfile map[string]any
|
||||
|
||||
// passwordProfile represent authentication method for,
|
||||
// newly created user profile.
|
||||
type passwordProfile struct {
|
||||
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// azureExtension represent custom attribute,
|
||||
// that can be added to user objects in Azure Active Directory (AD).
|
||||
type azureExtension struct {
|
||||
Name string `json:"name"`
|
||||
DataType string `json:"dataType"`
|
||||
TargetObjects []string `json:"targetObjects"`
|
||||
}
|
||||
|
||||
// NewAzureManager creates a new instance of the AzureManager.
|
||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
@ -115,7 +93,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
manager := &AzureManager{
|
||||
return &AzureManager{
|
||||
ObjectID: config.ObjectID,
|
||||
ClientID: config.ClientID,
|
||||
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
||||
@ -123,14 +101,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
err := manager.configureAppMetadata()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
||||
@ -236,44 +207,14 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in azure AD Idp.
|
||||
func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := am.post("users", payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
profile[wtAccountIDField] = accountID
|
||||
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
profile[wtPendingInviteField] = true
|
||||
|
||||
return profile.userData(am.ClientID), nil
|
||||
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", selectFields)
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users/"+userID, q)
|
||||
if err != nil {
|
||||
@ -290,18 +231,17 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profile.userData(am.ClientID), nil
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", selectFields)
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users/"+email, q)
|
||||
if err != nil {
|
||||
@ -319,20 +259,15 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, profile.userData(am.ClientID))
|
||||
users = append(users, profile.userData())
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", selectFields)
|
||||
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users", q)
|
||||
if err != nil {
|
||||
@ -351,7 +286,10 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Value {
|
||||
users = append(users, profile.userData(am.ClientID))
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
@ -360,12 +298,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", selectFields)
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users", q)
|
||||
if err != nil {
|
||||
@ -384,67 +318,15 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles.Value {
|
||||
userData := profile.userData(am.ClientID)
|
||||
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
|
||||
data, err := am.helper.Marshal(map[string]any{
|
||||
wtAccountIDField: appMetadata.WTAccountID,
|
||||
wtPendingInviteField: appMetadata.WTPendingInvite,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
|
||||
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("updating idp metadata for user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -454,7 +336,7 @@ func (am *AzureManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure
|
||||
// DeleteUser from Azure.
|
||||
func (am *AzureManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
@ -491,81 +373,6 @@ func (am *AzureManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", extensionFields)
|
||||
|
||||
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
|
||||
body, err := am.get(resource, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var extensions struct{ Value []azureExtension }
|
||||
err = am.helper.Unmarshal(body, &extensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return extensions.Value, nil
|
||||
}
|
||||
|
||||
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
|
||||
extension := azureExtension{
|
||||
Name: name,
|
||||
DataType: "string",
|
||||
TargetObjects: []string{"User"},
|
||||
}
|
||||
|
||||
payload, err := am.helper.Marshal(extension)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
|
||||
body, err := am.post(resource, string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userExtension azureExtension
|
||||
err = am.helper.Unmarshal(body, &userExtension)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &userExtension, nil
|
||||
}
|
||||
|
||||
// configureAppMetadata sets up app metadata extensions if they do not exists.
|
||||
func (am *AzureManager) configureAppMetadata() error {
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
||||
|
||||
extensions, err := am.getUserExtensions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the wt_account_id extension does not already exist, create it.
|
||||
if !hasExtension(extensions, wtAccountIDField) {
|
||||
_, err = am.createUserExtension(wtAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If the wt_pending_invite extension does not already exist, create it.
|
||||
if !hasExtension(extensions, wtPendingInviteField) {
|
||||
_, err = am.createUserExtension(wtPendingInvite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
@ -602,44 +409,8 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// post perform Post requests.
|
||||
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
|
||||
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (ap azureProfile) userData(clientID string) *UserData {
|
||||
func (ap azureProfile) userData() *UserData {
|
||||
id, ok := ap["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
@ -655,66 +426,9 @@ func (ap azureProfile) userData(clientID string) *UserData {
|
||||
name = ""
|
||||
}
|
||||
|
||||
accountIDField := extensionName(wtAccountIDTpl, clientID)
|
||||
accountID, ok := ap[accountIDField].(string)
|
||||
if !ok {
|
||||
accountID = ""
|
||||
}
|
||||
|
||||
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
|
||||
pendingInvite, ok := ap[pendingInviteField].(bool)
|
||||
if !ok {
|
||||
pendingInvite = false
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: id,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pendingInvite,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
|
||||
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
|
||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
|
||||
|
||||
req := &azureProfile{
|
||||
"accountEnabled": true,
|
||||
"displayName": name,
|
||||
"mailNickName": strings.Join(strings.Split(name, " "), ""),
|
||||
"userPrincipalName": email,
|
||||
"passwordProfile": passwordProfile{
|
||||
ForceChangePasswordNextSignIn: true,
|
||||
Password: GeneratePassword(8, 1, 1, 1),
|
||||
},
|
||||
wtAccountIDField: accountID,
|
||||
wtPendingInviteField: true,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func extensionName(extensionTpl, clientID string) string {
|
||||
clientID = strings.ReplaceAll(clientID, "-", "")
|
||||
return fmt.Sprintf(extensionTpl, clientID)
|
||||
}
|
||||
|
||||
// hasExtension checks whether a given extension by name,
|
||||
// exists in an list of extensions.
|
||||
func hasExtension(extensions []azureExtension, name string) bool {
|
||||
for _, ext := range extensions {
|
||||
if ext.Name == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -8,15 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockAzureCredentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func TestAzureJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
@ -124,206 +115,63 @@ func TestAzureAuthenticate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureUpdateUserAppMetadata(t *testing.T) {
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAzureCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAzureCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
managerCreds: &mockAzureCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 204,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAzureCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 204,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAzureCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
reqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
|
||||
manager := &AzureManager{
|
||||
httpClient: &reqClient,
|
||||
credentials: testCase.managerCreds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
name string
|
||||
clientID string
|
||||
invite bool
|
||||
inputProfile azureProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "Good Request",
|
||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||
invite: false,
|
||||
name: "Good Request",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test1",
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test1@test.com",
|
||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test1@test.com",
|
||||
Name: "John Doe",
|
||||
ID: "test1",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Missing User ID",
|
||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||
invite: true,
|
||||
name: "Missing User ID",
|
||||
invite: true,
|
||||
inputProfile: azureProfile{
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test2@test.com",
|
||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test2@test.com",
|
||||
Name: "John Doe",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase3 := azureProfileTest{
|
||||
name: "Missing User Name",
|
||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||
invite: false,
|
||||
name: "Missing User Name",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test3",
|
||||
"userPrincipalName": "test3@test.com",
|
||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test3",
|
||||
Email: "test3@test.com",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase4 := azureProfileTest{
|
||||
name: "Missing Extension Fields",
|
||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test4",
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test4@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test4",
|
||||
Name: "John Doe",
|
||||
Email: "test4@test.com",
|
||||
AppMetadata: AppMetadata{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData(testCase.clientID)
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -5,15 +5,14 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// GoogleWorkspaceManager Google Workspace manager client instance.
|
||||
@ -73,17 +72,13 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
|
||||
}
|
||||
|
||||
service, err := admin.NewService(context.Background(),
|
||||
option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope),
|
||||
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
|
||||
option.WithCredentials(adminCredentials),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = configureAppMetadataSchema(service, config.CustomerID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GoogleWorkspaceManager{
|
||||
usersService: service.Users,
|
||||
CustomerID: config.CustomerID,
|
||||
@ -95,27 +90,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
metadata, err := gm.helper.Marshal(appMetadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user := &admin.User{
|
||||
CustomSchemas: map[string]googleapi.RawMessage{
|
||||
"app_metadata": metadata,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = gm.usersService.Update(userID, user).Do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -130,23 +105,23 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
|
||||
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
return parseGoogleWorkspaceUser(user)
|
||||
userData := parseGoogleWorkspaceUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID)
|
||||
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do()
|
||||
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0)
|
||||
for _, user := range usersList.Users {
|
||||
userData, err := parseGoogleWorkspaceUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData := parseGoogleWorkspaceUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
usersData = append(usersData, userData)
|
||||
}
|
||||
@ -168,61 +143,16 @@ func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, erro
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range usersList.Users {
|
||||
userData, err := parseGoogleWorkspaceUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
userData := parseGoogleWorkspaceUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
invite := true
|
||||
metadata := AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
}
|
||||
|
||||
username := &admin.UserName{}
|
||||
fields := strings.Fields(name)
|
||||
if n := len(fields); n > 0 {
|
||||
username.GivenName = strings.Join(fields[:n-1], " ")
|
||||
username.FamilyName = fields[n-1]
|
||||
}
|
||||
|
||||
payload, err := gm.helper.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &admin.User{
|
||||
Name: username,
|
||||
PrimaryEmail: email,
|
||||
CustomSchemas: map[string]googleapi.RawMessage{
|
||||
"app_metadata": payload,
|
||||
},
|
||||
Password: GeneratePassword(8, 1, 1, 1),
|
||||
}
|
||||
user, err = gm.usersService.Insert(user).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
return parseGoogleWorkspaceUser(user)
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@ -237,13 +167,8 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
|
||||
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
userData, err := parseGoogleWorkspaceUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, userData)
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
|
||||
return users, nil
|
||||
}
|
||||
@ -281,8 +206,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
|
||||
creds, err := google.CredentialsFromJSON(
|
||||
context.Background(),
|
||||
decodeKey,
|
||||
admin.AdminDirectoryUserschemaScope,
|
||||
admin.AdminDirectoryUserScope,
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err == nil {
|
||||
// No need to fallback to the default Google credentials path
|
||||
@ -294,8 +218,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
|
||||
|
||||
creds, err = google.FindDefaultCredentials(
|
||||
context.Background(),
|
||||
admin.AdminDirectoryUserschemaScope,
|
||||
admin.AdminDirectoryUserScope,
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -304,62 +227,11 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace.
|
||||
func configureAppMetadataSchema(service *admin.Service, customerID string) error {
|
||||
schemaList, err := service.Schemas.List(customerID).Do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// checks if app_metadata schema is already created
|
||||
for _, schema := range schemaList.Schemas {
|
||||
if schema.SchemaName == "app_metadata" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// create new app_metadata schema
|
||||
appMetadataSchema := &admin.Schema{
|
||||
SchemaName: "app_metadata",
|
||||
Fields: []*admin.SchemaFieldSpec{
|
||||
{
|
||||
FieldName: "wt_account_id",
|
||||
FieldType: "STRING",
|
||||
MultiValued: false,
|
||||
},
|
||||
{
|
||||
FieldName: "wt_pending_invite",
|
||||
FieldType: "BOOL",
|
||||
MultiValued: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = service.Schemas.Insert(customerID, appMetadataSchema).Do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseGoogleWorkspaceUser parse google user to UserData.
|
||||
func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) {
|
||||
var appMetadata AppMetadata
|
||||
|
||||
// Get app metadata from custom schemas
|
||||
if user.CustomSchemas != nil {
|
||||
rawMessage := user.CustomSchemas["app_metadata"]
|
||||
helper := JsonParser{}
|
||||
|
||||
if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
|
||||
return &UserData{
|
||||
ID: user.Id,
|
||||
Email: user.PrimaryEmail,
|
||||
Name: user.Name.FullName,
|
||||
AppMetadata: appMetadata,
|
||||
}, nil
|
||||
ID: user.Id,
|
||||
Email: user.PrimaryEmail,
|
||||
Name: user.Name.FullName,
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnsetAccountID is a special key to map users without an account ID
|
||||
UnsetAccountID = "unset"
|
||||
)
|
||||
|
||||
// Manager idp manager interface
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
|
||||
|
@ -1,12 +1,10 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -18,11 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
wtAccountID = "wt_account_id"
|
||||
wtPendingInvite = "wt_pending_invite"
|
||||
)
|
||||
|
||||
// KeycloakManager keycloak manager client instance.
|
||||
type KeycloakManager struct {
|
||||
adminEndpoint string
|
||||
@ -51,28 +44,10 @@ type KeycloakCredentials struct {
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// keycloakUserCredential describe the authentication method for,
|
||||
// newly created user profile.
|
||||
type keycloakUserCredential struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
Temporary bool `json:"temporary"`
|
||||
}
|
||||
|
||||
// keycloakUserAttributes holds additional user data fields.
|
||||
type keycloakUserAttributes map[string][]string
|
||||
|
||||
// createUserRequest is a user create request.
|
||||
type keycloakCreateUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Enabled bool `json:"enabled"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Credentials []keycloakUserCredential `json:"credentials"`
|
||||
Attributes keycloakUserAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
// keycloakProfile represents an keycloak user profile response.
|
||||
// keycloakProfile represents a keycloak user profile response.
|
||||
type keycloakProfile struct {
|
||||
ID string `json:"id"`
|
||||
CreatedTimestamp int64 `json:"createdTimestamp"`
|
||||
@ -230,62 +205,8 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
invite := true
|
||||
appMetadata := AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
}
|
||||
|
||||
payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users", km.adminEndpoint)
|
||||
payload := strings.NewReader(payloadString)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, reqURL, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
locationHeader := resp.Header.Get("location")
|
||||
userID, err := extractUserIDFromLocationHeader(locationHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return km.GetUserDataByID(userID, appMetadata)
|
||||
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@ -319,7 +240,7 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get("users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -338,12 +259,9 @@ func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadat
|
||||
return profile.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
// GetAccount returns all the users for a given account profile.
|
||||
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("q", wtAccountID+":"+accountID)
|
||||
|
||||
body, err := km.get("users", q)
|
||||
profiles, err := km.fetchAllUserProfiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -352,15 +270,12 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
km.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
users = append(users, profile.userData())
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
@ -369,15 +284,7 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
totalUsers, err := km.totalUsersCount()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get("users", q)
|
||||
profiles, err := km.fetchAllUserProfiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -386,78 +293,17 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
attrs := keycloakUserAttributes{}
|
||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||
if appMetadata.WTPendingInvite != nil {
|
||||
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
|
||||
} else {
|
||||
attrs.Set(wtPendingInvite, "false")
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID)
|
||||
data, err := km.helper.Marshal(map[string]any{
|
||||
"attributes": attrs,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, reqURL, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -467,7 +313,7 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloack
|
||||
// DeleteUser from Keycloak by user ID.
|
||||
func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
@ -475,7 +321,6 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
||||
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -508,32 +353,27 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||
attrs := keycloakUserAttributes{}
|
||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
|
||||
|
||||
req := &keycloakCreateUserRequest{
|
||||
Email: email,
|
||||
Username: name,
|
||||
Enabled: true,
|
||||
EmailVerified: true,
|
||||
Credentials: []keycloakUserCredential{
|
||||
{
|
||||
Type: "password",
|
||||
Value: GeneratePassword(8, 1, 1, 1),
|
||||
Temporary: false,
|
||||
},
|
||||
},
|
||||
Attributes: attrs,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get("users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
@ -588,53 +428,11 @@ func (km *KeycloakManager) totalUsersCount() (*int, error) {
|
||||
return &count, nil
|
||||
}
|
||||
|
||||
// extractUserIDFromLocationHeader extracts the user ID from the location,
|
||||
// header once the user is created successfully
|
||||
func extractUserIDFromLocationHeader(locationHeader string) (string, error) {
|
||||
userURL, err := url.Parse(locationHeader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return path.Base(userURL.Path), nil
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (kp keycloakProfile) userData() *UserData {
|
||||
accountID := kp.Attributes.Get(wtAccountID)
|
||||
pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite))
|
||||
if err != nil {
|
||||
pendingInvite = false
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: kp.Email,
|
||||
Name: kp.Username,
|
||||
ID: kp.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pendingInvite,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the key to value. It replaces any existing
|
||||
// values.
|
||||
func (ka keycloakUserAttributes) Set(key, value string) {
|
||||
ka[key] = []string{value}
|
||||
}
|
||||
|
||||
// Get returns the first value associated with the given key.
|
||||
// If there are no values associated with the key, Get returns
|
||||
// the empty string.
|
||||
func (ka keycloakUserAttributes) Get(key string) string {
|
||||
if ka == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
values := ka[key]
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return values[0]
|
||||
}
|
||||
|
@ -84,15 +84,6 @@ func TestNewKeycloakManager(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockKeycloakCredentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
@ -316,108 +307,3 @@ func TestKeycloakAuthenticate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakUpdateUserAppMetadata(t *testing.T) {
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 204,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID),
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 204,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
reqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
|
||||
manager := &KeycloakManager{
|
||||
httpClient: &reqClient,
|
||||
credentials: testCase.managerCreds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// OktaManager okta manager client instance.
|
||||
@ -76,11 +76,6 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updateUserProfileSchema(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials := &OktaCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
@ -103,49 +98,8 @@ func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
||||
func (om *OktaManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
var (
|
||||
sendEmail = true
|
||||
activate = true
|
||||
userProfile = okta.UserProfile{
|
||||
"email": email,
|
||||
"login": email,
|
||||
wtAccountID: accountID,
|
||||
wtPendingInvite: true,
|
||||
}
|
||||
)
|
||||
|
||||
fields := strings.Fields(name)
|
||||
if n := len(fields); n > 0 {
|
||||
userProfile["firstName"] = strings.Join(fields[:n-1], " ")
|
||||
userProfile["lastName"] = fields[n-1]
|
||||
}
|
||||
|
||||
user, resp, err := om.client.User.CreateUser(context.Background(),
|
||||
okta.CreateUserRequest{
|
||||
Profile: &userProfile,
|
||||
},
|
||||
&query.Params{
|
||||
Activate: &activate,
|
||||
SendEmail: &sendEmail,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return parseOktaUser(user)
|
||||
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
@ -166,7 +120,13 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
return parseOktaUser(user)
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@ -200,8 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
search := fmt.Sprintf("profile.wt_account_id eq %q", accountID)
|
||||
users, resp, err := om.client.User.ListUsers(context.Background(), &query.Params{Search: search})
|
||||
users, resp, err := om.client.User.ListUsers(context.Background(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -223,6 +182,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
list = append(list, userData)
|
||||
}
|
||||
@ -256,13 +216,7 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
@ -270,46 +224,6 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
profile := *user.Profile
|
||||
|
||||
if appMetadata.WTPendingInvite != nil {
|
||||
profile[wtPendingInvite] = *appMetadata.WTPendingInvite
|
||||
}
|
||||
|
||||
if appMetadata.WTAccountID != "" {
|
||||
profile[wtAccountID] = appMetadata.WTAccountID
|
||||
}
|
||||
|
||||
user.Profile = &profile
|
||||
_, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -341,60 +255,12 @@ func (om *OktaManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
||||
// wt_account_id and wt_pending_invite.
|
||||
func updateUserProfileSchema(client *okta.Client) error {
|
||||
// Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird
|
||||
userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}}
|
||||
|
||||
_, resp, err := client.UserSchema.UpdateUserProfile(
|
||||
context.Background(),
|
||||
"default",
|
||||
okta.UserSchema{
|
||||
Definitions: &okta.UserSchemaDefinitions{
|
||||
Custom: &okta.UserSchemaPublic{
|
||||
Id: "#custom",
|
||||
Type: "object",
|
||||
Properties: map[string]*okta.UserSchemaAttribute{
|
||||
wtAccountID: {
|
||||
MaxLength: 100,
|
||||
MinLength: 1,
|
||||
Required: new(bool),
|
||||
Scope: "NONE",
|
||||
Title: "Wt Account Id",
|
||||
Type: "string",
|
||||
Permissions: userPermissions,
|
||||
},
|
||||
wtPendingInvite: {
|
||||
Required: new(bool),
|
||||
Scope: "NONE",
|
||||
Title: "Wt Pending Invite",
|
||||
Type: "boolean",
|
||||
Permissions: userPermissions,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unable to update user profile schema, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOktaUserToUserData parse okta user to UserData.
|
||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||
var oktaUser struct {
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
AccountID string `json:"wt_account_id"`
|
||||
PendingInvite bool `json:"wt_pending_invite"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
@ -418,9 +284,5 @@ func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||
Email: oktaUser.Email,
|
||||
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
|
||||
ID: user.Id,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: oktaUser.AccountID,
|
||||
WTPendingInvite: &oktaUser.PendingInvite,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -1,31 +1,28 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseOktaUser(t *testing.T) {
|
||||
type parseOktaUserTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile *okta.User
|
||||
expectedUserData *UserData
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
}
|
||||
|
||||
parseOktaTestCase1 := parseOktaUserTest{
|
||||
name: "Good Request",
|
||||
invite: true,
|
||||
name: "Good Request",
|
||||
inputProfile: &okta.User{
|
||||
Id: "123",
|
||||
Profile: &okta.UserProfile{
|
||||
"email": "test@example.com",
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"wt_account_id": "456",
|
||||
"wt_pending_invite": true,
|
||||
"email": "test@example.com",
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
},
|
||||
},
|
||||
expectedUserData: &UserData{
|
||||
@ -41,36 +38,17 @@ func TestParseOktaUser(t *testing.T) {
|
||||
|
||||
parseOktaTestCase2 := parseOktaUserTest{
|
||||
name: "Invalid okta user",
|
||||
invite: true,
|
||||
inputProfile: nil,
|
||||
expectedUserData: nil,
|
||||
assertErrFunc: assert.Error,
|
||||
}
|
||||
|
||||
parseOktaTestCase3 := parseOktaUserTest{
|
||||
name: "Invalid pending invite type",
|
||||
invite: false,
|
||||
inputProfile: &okta.User{
|
||||
Id: "123",
|
||||
Profile: &okta.UserProfile{
|
||||
"email": "test@example.com",
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"wt_account_id": "456",
|
||||
"wt_pending_invite": "true",
|
||||
},
|
||||
},
|
||||
expectedUserData: nil,
|
||||
assertErrFunc: assert.Error,
|
||||
}
|
||||
|
||||
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2, parseOktaTestCase3} {
|
||||
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
userData, err := parseOktaUser(testCase.inputProfile)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
|
||||
|
||||
if err == nil {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
|
||||
}
|
||||
})
|
||||
@ -83,13 +61,5 @@ func userDataEqual(a, b *UserData) bool {
|
||||
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
|
||||
return false
|
||||
}
|
||||
if a.AppMetadata.WTAccountID != b.AppMetadata.WTAccountID {
|
||||
return false
|
||||
}
|
||||
|
||||
if a.AppMetadata.WTPendingInvite != nil && b.AppMetadata.WTPendingInvite != nil &&
|
||||
*a.AppMetadata.WTPendingInvite != *b.AppMetadata.WTPendingInvite {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -1,13 +1,10 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -68,12 +65,6 @@ type zitadelUser struct {
|
||||
|
||||
type zitadelAttributes map[string][]map[string]any
|
||||
|
||||
// zitadelMetadata holds additional user data.
|
||||
type zitadelMetadata struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// zitadelProfile represents an zitadel user profile response.
|
||||
type zitadelProfile struct {
|
||||
ID string `json:"id"`
|
||||
@ -82,7 +73,6 @@ type zitadelProfile struct {
|
||||
PreferredLoginName string `json:"preferredLoginName"`
|
||||
LoginNames []string `json:"loginNames"`
|
||||
Human *zitadelUser `json:"human"`
|
||||
Metadata []zitadelMetadata
|
||||
}
|
||||
|
||||
// NewZitadelManager creates a new instance of the ZitadelManager.
|
||||
@ -235,42 +225,8 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite.
|
||||
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
payload, err := buildZitadelCreateUserRequestPayload(email, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/human/_import", payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
var result struct {
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
err = zm.helper.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
invite := true
|
||||
appMetadata := AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
}
|
||||
|
||||
// Add metadata to new user
|
||||
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return zm.GetUserDataByID(result.UserID, appMetadata)
|
||||
func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@ -308,12 +264,6 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Result {
|
||||
metadata, err := zm.getUserMetadata(profile.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.Metadata = metadata
|
||||
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
@ -337,18 +287,15 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metadata, err := zm.getUserMetadata(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.User.Metadata = metadata
|
||||
userData := profile.User.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return profile.User.userData(), nil
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
accounts, err := zm.GetAllAccounts()
|
||||
body, err := zm.post("users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -357,7 +304,21 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
return accounts[accountID], nil
|
||||
var profiles struct{ Result []zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
@ -380,22 +341,8 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles.Result {
|
||||
// fetch user metadata
|
||||
metadata, err := zm.getUserMetadata(profile.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.Metadata = metadata
|
||||
|
||||
userData := profile.userData()
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
@ -403,42 +350,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
if appMetadata.WTPendingInvite == nil {
|
||||
appMetadata.WTPendingInvite = new(bool)
|
||||
}
|
||||
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
|
||||
|
||||
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
|
||||
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
|
||||
|
||||
metadata := zitadelAttributes{
|
||||
"metadata": {
|
||||
{
|
||||
"key": wtAccountID,
|
||||
"value": wtAccountIDValue,
|
||||
},
|
||||
{
|
||||
"key": wtPendingInvite,
|
||||
"value": wtPendingInviteValue,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload, err := zm.helper.Marshal(metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
|
||||
_, err = zm.post(resource, string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -460,24 +372,6 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// getUserMetadata requests user metadata from zitadel via ID.
|
||||
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||
body, err := zm.post(resource, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metadata struct{ Result []zitadelMetadata }
|
||||
err = zm.helper.Unmarshal(body, &metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return metadata.Result, nil
|
||||
}
|
||||
|
||||
// post perform Post requests.
|
||||
@ -517,38 +411,7 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// delete perform Delete requests.
|
||||
func (zm *ZitadelManager) delete(resource string) error {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := zm.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (zm *ZitadelManager) delete(_ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -588,38 +451,13 @@ func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// value returns string represented by the base64 string value.
|
||||
func (zm zitadelMetadata) value() string {
|
||||
value, err := base64.StdEncoding.DecodeString(zm.Value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(value)
|
||||
}
|
||||
|
||||
// userData construct user data from zitadel profile.
|
||||
func (zp zitadelProfile) userData() *UserData {
|
||||
var (
|
||||
email string
|
||||
name string
|
||||
wtAccountIDValue string
|
||||
wtPendingInviteValue bool
|
||||
email string
|
||||
name string
|
||||
)
|
||||
|
||||
for _, metadata := range zp.Metadata {
|
||||
if metadata.Key == wtAccountID {
|
||||
wtAccountIDValue = metadata.value()
|
||||
}
|
||||
|
||||
if metadata.Key == wtPendingInvite {
|
||||
value, err := strconv.ParseBool(metadata.value())
|
||||
if err == nil {
|
||||
wtPendingInviteValue = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Obtain the email for the human account and the login name,
|
||||
// for the machine account.
|
||||
if zp.Human != nil {
|
||||
@ -636,39 +474,5 @@ func (zp zitadelProfile) userData() *UserData {
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: zp.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: wtAccountIDValue,
|
||||
WTPendingInvite: &wtPendingInviteValue,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
|
||||
var firstName, lastName string
|
||||
|
||||
words := strings.Fields(name)
|
||||
if n := len(words); n > 0 {
|
||||
firstName = strings.Join(words[:n-1], " ")
|
||||
lastName = words[n-1]
|
||||
}
|
||||
|
||||
req := &zitadelUser{
|
||||
UserName: name,
|
||||
Profile: zitadelUserInfo{
|
||||
FirstName: strings.TrimSpace(firstName),
|
||||
LastName: strings.TrimSpace(lastName),
|
||||
DisplayName: name,
|
||||
},
|
||||
Email: zitadelEmail{
|
||||
Email: email,
|
||||
IsEmailVerified: false,
|
||||
},
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
@ -7,9 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewZitadelManager(t *testing.T) {
|
||||
@ -63,15 +64,6 @@ func TestNewZitadelManager(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockZitadelCredentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func TestZitadelRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
@ -296,98 +288,6 @@ func TestZitadelAuthenticate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockZitadelCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
reqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
httpClient: &reqClient,
|
||||
credentials: testCase.managerCreds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZitadelProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
name string
|
||||
@ -418,16 +318,6 @@ func TestZitadelProfile(t *testing.T) {
|
||||
IsEmailVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: []zitadelMetadata{
|
||||
{
|
||||
Key: "wt_account_id",
|
||||
Value: "MQ==",
|
||||
},
|
||||
{
|
||||
Key: "wt_pending_invite",
|
||||
Value: "ZmFsc2U=",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test1",
|
||||
@ -451,16 +341,6 @@ func TestZitadelProfile(t *testing.T) {
|
||||
"machine",
|
||||
},
|
||||
Human: nil,
|
||||
Metadata: []zitadelMetadata{
|
||||
{
|
||||
Key: "wt_account_id",
|
||||
Value: "MQ==",
|
||||
},
|
||||
{
|
||||
Key: "wt_pending_invite",
|
||||
Value: "dHJ1ZQ==",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test2",
|
||||
@ -480,8 +360,6 @@ func TestZitadelProfile(t *testing.T) {
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user