Add Pagination for IdP Users Fetch (#1210)

* Retrieve all workspace users via pagination, excluding custom user attributes

* Retrieve all authentik users via pagination

* Retrieve all Azure AD users via pagination

* Simplify user data appending operation

Reduced unnecessary iteration and used an efficient way to append all users to 'indexedUsers'

* Fix ineffectual assignment to reqURL

* Retrieve all Okta users via pagination

* Add missing GetAccount metrics

* Refactor

* minimize memory allocation

Refactored the memory allocation for the 'users' slice in the Okta IDP code. Previously, the slice was only initialized but not given a size. Now the size of userList is utilized to optimize memory allocation, reducing potential slice resizing and memory re-allocation costs while appending users.

* Add logging for entries received from IdP management

Added informative and debug logging statements in account.go file. Logging has been added to identify the number of entries received from Identity Provider (IdP) management. This will aid in tracking and debugging any potential data ingestion issues.
This commit is contained in:
Bethuel Mmbaga 2023-10-11 17:09:30 +03:00 committed by GitHub
parent 3c485dc7a1
commit 4ad14cb46b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 191 additions and 122 deletions

View File

@ -946,6 +946,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
if err != nil { if err != nil {
return err return err
} }
log.Infof("%d entries received from IdP management", len(userData))
// If the Identity Provider does not support writing AppMetadata, // If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field. // in cases like this, we expect it to return all users in an "unset" field.
@ -1045,6 +1046,7 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("%d entries received from IdP management", len(userData))
dataMap := make(map[string]*idp.UserData, len(userData)) dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData { for _, datum := range userData {

View File

@ -251,34 +251,18 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
ctx, err := am.authenticationContext() users, err := am.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil { if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount() am.appMetrics.IDPMetrics().CountGetAccount()
} }
if resp.StatusCode != http.StatusOK { for index, user := range users {
if am.appMetrics != nil { user.AppMetadata.WTAccountID = accountID
am.appMetrics.IDPMetrics().CountRequestStatusError() users[index] = user
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
userData := parseAuthentikUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
} }
return users, nil return users, nil
@ -287,20 +271,37 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
page := int32(1)
for {
ctx, err := am.authenticationContext() ctx, err := am.authenticationContext()
if err != nil { if err != nil {
return nil, err return nil, err
} }
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute() userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() _ = resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil { if am.appMetrics != nil {
@ -309,13 +310,18 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
} }
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results { for _, user := range userList.Results {
userData := parseAuthentikUser(user) users = append(users, parseAuthentikUser(user))
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
} }
return indexedUsers, nil page = int32(userList.GetPagination().Next)
if userList.GetPagination().Next == 0 {
break
}
}
return users, nil
} }
// CreateUser creates a new user in authentik Idp and sends an invitation. // CreateUser creates a new user in authentik Idp and sends an invitation.

View File

@ -266,10 +266,7 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
q := url.Values{} users, err := am.getAllUsers()
q.Add("$select", profileFields)
body, err := am.get("users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -278,18 +275,9 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
am.appMetrics.IDPMetrics().CountGetAccount() am.appMetrics.IDPMetrics().CountGetAccount()
} }
var profiles struct{ Value []azureProfile } for index, user := range users {
err = am.helper.Unmarshal(body, &profiles) user.AppMetadata.WTAccountID = accountID
if err != nil { users[index] = user
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Value {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
} }
return users, nil return users, nil
@ -298,28 +286,16 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
q := url.Values{} users, err := am.getAllUsers()
q.Add("$select", profileFields)
body, err := am.get("users", q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
} }
return indexedUsers, nil return indexedUsers, nil
@ -373,6 +349,39 @@ func (am *AzureManager) DeleteUser(userID string) error {
return nil return nil
} }
// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
q := url.Values{}
q.Add("$select", profileFields)
q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; {
body, err := am.get(nextLink, q)
if err != nil {
return nil, err
}
var profiles struct {
Value []azureProfile
NextLink string `json:"@odata.nextLink"`
}
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
for _, profile := range profiles.Value {
users = append(users, profile.userData())
}
nextLink = profiles.NextLink
}
return users, nil
}
// get perform Get requests. // get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
@ -380,7 +389,14 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
return nil, err return nil, err
} }
reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode()) var reqURL string
if strings.HasPrefix(resource, "https") {
// Already an absolute URL for paging
reqURL = resource
} else {
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
}
req, err := http.NewRequest(http.MethodGet, reqURL, nil) req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -96,7 +96,7 @@ func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata)
// GetUserDataByID requests user data from Google Workspace via ID. // GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Projection("full").Do() user, err := gm.usersService.Get(userID).Do()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,43 +113,69 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do() users, err := gm.getAllUsers()
if err != nil {
return nil, err
}
usersData := make([]*UserData, 0)
for _, user := range usersList.Users {
userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata.WTAccountID = accountID
usersData = append(usersData, userData)
}
return usersData, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if gm.appMetrics != nil { if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAllAccounts() gm.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
} }
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
userData := parseGoogleWorkspaceUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
} }
return indexedUsers, nil return indexedUsers, nil
} }
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
pageToken := ""
for {
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
if pageToken != "" {
call.PageToken(pageToken)
}
resp, err := call.Do()
if err != nil {
return nil, err
}
for _, user := range resp.Users {
users = append(users, parseGoogleWorkspaceUser(user))
}
pageToken = resp.NextPageToken
if pageToken == "" {
break
}
}
return users, nil
}
// CreateUser creates a new user in Google Workspace and sends an invitation. // CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
@ -158,7 +184,7 @@ func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, erro
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) { func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Projection("full").Do() user, err := gm.usersService.Get(email).Do()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/okta/okta-sdk-golang/v2/okta" "github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
@ -160,7 +161,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
users, resp, err := om.client.User.ListUsers(context.Background(), nil) users, err := om.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -169,39 +170,40 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
om.appMetrics.IDPMetrics().CountGetAccount() om.appMetrics.IDPMetrics().CountGetAccount()
} }
if resp.StatusCode != http.StatusOK { for index, user := range users {
if om.appMetrics != nil { user.AppMetadata.WTAccountID = accountID
om.appMetrics.IDPMetrics().CountRequestStatusError() users[index] = user
}
return nil, fmt.Errorf("unable to get account, statusCode %d", resp.StatusCode)
} }
list := make([]*UserData, 0) return users, nil
for _, user := range users {
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
userData.AppMetadata.WTAccountID = accountID
list = append(list, userData)
}
return list, nil
} }
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
users, resp, err := om.client.User.ListUsers(context.Background(), nil) users, err := om.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if om.appMetrics != nil { if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountGetAllAccounts() om.appMetrics.IDPMetrics().CountGetAllAccounts()
} }
return indexedUsers, nil
}
// getAllUsers returns all users in an Okta account.
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
qp := query.NewQueryParams(query.WithLimit(200))
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil { if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError() om.appMetrics.IDPMetrics().CountRequestStatusError()
@ -209,17 +211,34 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
} }
indexedUsers := make(map[string][]*UserData) for resp.HasNextPage() {
for _, user := range users { paginatedUsers := make([]*okta.User, 0)
resp, err = resp.Next(context.Background(), &paginatedUsers)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
userList = append(userList, paginatedUsers...)
}
users := make([]*UserData, 0, len(userList))
for _, user := range userList {
userData, err := parseOktaUser(user) userData, err := parseOktaUser(user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) users = append(users, userData)
} }
return indexedUsers, nil return users, nil
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.