mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 14:03:35 +01:00
Fix(auth0) caching Users by accountId
This commit is contained in:
parent
1e444f58c1
commit
c86c620016
@ -328,7 +328,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
|
||||
|
||||
queriedUsers := make([]*idp.UserData, 0)
|
||||
if !isNil(am.idpManager) {
|
||||
queriedUsers, err = am.idpManager.GetBatchedUserData(accountID)
|
||||
queriedUsers, err = am.idpManager.GetAllUsers(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -56,6 +56,7 @@ type Auth0Credentials struct {
|
||||
}
|
||||
|
||||
type Auth0Profile struct {
|
||||
AccountId string `json:"wt_account_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
@ -223,59 +224,33 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
func batchRequestUsersUrl(authIssuer, accountId string, page int) (string, url.Values, error) {
|
||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
q.Set("search_engine", "v3")
|
||||
q.Set("q", "app_metadata.wt_account_id:"+accountId)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), q, nil
|
||||
}
|
||||
|
||||
func requestByUserIdUrl(authIssuer, userId string) string {
|
||||
return authIssuer + "/api/v2/users/" + userId
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
req.Header.Add("authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
// Gets all users from cache, if the cache exists
|
||||
// Otherwise we will initialize the cache with creating the export job on auth0
|
||||
func (am *Auth0Manager) GetAllUsers(accountId string) ([]*UserData, error) {
|
||||
if len(am.cachedUsersByAccountId[accountId]) == 0 {
|
||||
err := am.createExportUsersJob(accountId)
|
||||
if err != nil {
|
||||
log.Errorf("error while closing body for url %s: %v", url, err)
|
||||
log.Debugf("Couldn't cache users; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}()
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var list []*UserData
|
||||
|
||||
cachedUsers := am.cachedUsersByAccountId[accountId]
|
||||
for _, val := range cachedUsers {
|
||||
list = append(list, &UserData{
|
||||
Name: val.Name,
|
||||
Email: val.Email,
|
||||
ID: val.UserID,
|
||||
})
|
||||
}
|
||||
return body, nil
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// This creates an export job on auth0 for all users.
|
||||
func (am *Auth0Manager) CreateExportUsersJob(accountId string) error {
|
||||
func (am *Auth0Manager) createExportUsersJob(accountId string) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -283,7 +258,8 @@ func (am *Auth0Manager) CreateExportUsersJob(accountId string) error {
|
||||
|
||||
reqURL := am.authIssuer + "/api/v2/jobs/users-exports"
|
||||
|
||||
payloadString := fmt.Sprintf("{\"format\": \"json\"}")
|
||||
payloadString := fmt.Sprintf("{\"format\": \"json\"," +
|
||||
"\"fields\": [{\"name\": \"created_at\"}, {\"name\": \"last_login\"},{\"name\": \"user_id\"}, {\"name\": \"email\"}, {\"name\": \"name\"}, {\"name\": \"app_metadata.wt_account_id\", \"export_as\": \"wt_account_id\"}]}")
|
||||
|
||||
payload := strings.NewReader(payloadString)
|
||||
|
||||
@ -340,7 +316,7 @@ func (am *Auth0Manager) CreateExportUsersJob(accountId string) error {
|
||||
}
|
||||
|
||||
if done {
|
||||
err = am.cacheUsers(accountId, downloadLink)
|
||||
err = am.cacheUsers(downloadLink)
|
||||
if err != nil {
|
||||
log.Debugf("Failed to cache users via download link; %v", err)
|
||||
}
|
||||
@ -350,8 +326,8 @@ func (am *Auth0Manager) CreateExportUsersJob(accountId string) error {
|
||||
}
|
||||
|
||||
// Downloads the users from auth0 and caches it in memory
|
||||
// We don't need
|
||||
func (am *Auth0Manager) cacheUsers(accountId, location string) error {
|
||||
// Users are only cached if they have an wt_account_id stored in auth0
|
||||
func (am *Auth0Manager) cacheUsers(location string) error {
|
||||
body, err := doGetReq(am.httpClient, location, "")
|
||||
if err != nil {
|
||||
log.Debugf("Can't download cached users; %v", err)
|
||||
@ -374,8 +350,9 @@ func (am *Auth0Manager) cacheUsers(accountId, location string) error {
|
||||
log.Errorf("Couldn't decode profile; %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
am.cachedUsersByAccountId[accountId] = append(am.cachedUsersByAccountId[accountId], profile)
|
||||
if profile.AccountId != "" {
|
||||
am.cachedUsersByAccountId[profile.AccountId] = append(am.cachedUsersByAccountId[profile.AccountId], profile)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -388,6 +365,7 @@ func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobId string)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Export job status stopped...\n")
|
||||
return false, "", ctx.Err()
|
||||
case <-retry.C:
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
@ -407,6 +385,8 @@ func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobId string)
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
log.Debugf("Current export job status is %v", status.Status)
|
||||
|
||||
if status.Status != "completed" {
|
||||
continue
|
||||
}
|
||||
@ -416,100 +396,71 @@ func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobId string)
|
||||
}
|
||||
}
|
||||
|
||||
// This recaches every use from account
|
||||
func (am *Auth0Manager) ForceUpdateUserCache(accountId string) {
|
||||
// Invalidates old cache for Account and re-queries it from auth0
|
||||
func (am *Auth0Manager) forceUpdateUserCache(accountId string) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
var list []Auth0Profile
|
||||
|
||||
func (am *Auth0Manager) GetBatchedUserData(accountId string) ([]*UserData, error) {
|
||||
// first time calling this
|
||||
// we need to check whether we need to call for users we don't have
|
||||
if len(am.cachedUsersByAccountId[accountId]) == 0 {
|
||||
err := am.CreateExportUsersJob(accountId)
|
||||
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
||||
// auth0 limitation of 1000 users via this endpoint
|
||||
for page := 0; page < 20; page++ {
|
||||
reqURL, query, err := batchRequestUsersUrl(am.authIssuer, accountId, page)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't cache users; %v", err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var batch []Auth0Profile
|
||||
err = json.Unmarshal(body, &batch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("requested batch; %v", batch)
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for user := range batch {
|
||||
list = append(list, batch[user])
|
||||
}
|
||||
}
|
||||
am.cachedUsersByAccountId[accountId] = list
|
||||
|
||||
var list []*UserData
|
||||
|
||||
cachedUsers := am.cachedUsersByAccountId[accountId]
|
||||
for _, val := range cachedUsers {
|
||||
list = append(list, &UserData{
|
||||
Name: val.Name,
|
||||
Email: val.Email,
|
||||
ID: val.UserID,
|
||||
})
|
||||
}
|
||||
|
||||
return list, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBatchedUserData requests users in batches from Auth0
|
||||
// func (am *Auth0Manager) GetBatchedUserData(accountId string) ([]*UserData, error) {
|
||||
// jwtToken, err := am.credentials.Authenticate()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var list []*UserData
|
||||
|
||||
// // https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
||||
// // auth0 limitation of 1000 users via this endpoint
|
||||
// for page := 0; page < 20; page++ {
|
||||
// reqURL, query, err := batchRequestUsersUrl(am.authIssuer, accountId, page)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
// req.Header.Add("content-type", "application/json")
|
||||
|
||||
// res, err := am.httpClient.Do(req)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// body, err := io.ReadAll(res.Body)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var batch []UserData
|
||||
// err = json.Unmarshal(body, &batch)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// log.Debugf("requested batch; %v", batch)
|
||||
|
||||
// err = res.Body.Close()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// if res.StatusCode != 200 {
|
||||
// return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
|
||||
// }
|
||||
|
||||
// if len(batch) == 0 {
|
||||
// return list, nil
|
||||
// }
|
||||
|
||||
// for user := range batch {
|
||||
// list = append(list, &batch[user])
|
||||
// }
|
||||
// }
|
||||
|
||||
// return list, nil
|
||||
// }
|
||||
|
||||
// GetUserDataByID requests user data from auth0 via ID
|
||||
func (am *Auth0Manager) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
@ -601,3 +552,54 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userId string, appMetadata AppMeta
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func batchRequestUsersUrl(authIssuer, accountId string, page int) (string, url.Values, error) {
|
||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
q.Set("search_engine", "v3")
|
||||
q.Set("q", "app_metadata.wt_account_id:"+accountId)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), q, nil
|
||||
}
|
||||
|
||||
func requestByUserIdUrl(authIssuer, userId string) string {
|
||||
return authIssuer + "/api/v2/users/" + userId
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
req.Header.Add("authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing body for url %s: %v", url, err)
|
||||
}
|
||||
}()
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
@ -11,8 +11,7 @@ import (
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetBatchedUserData(accountId string) ([]*UserData, error)
|
||||
CreateExportUsersJob(accountId string) error
|
||||
GetAllUsers(accountId string) ([]*UserData, error)
|
||||
}
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
|
Loading…
Reference in New Issue
Block a user