Add Azure Idp Manager (#822)

Added intergration with Azure IDP user API.

Use the steps in azure-ad.md for configuration:
cb03373f8f/docs/integrations/identity-providers/self-hosted/azure-ad.md
This commit is contained in:
Bethuel 2023-05-03 15:51:44 +03:00 committed by GitHub
parent ecac82a5ae
commit 873b56f856
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 994 additions and 0 deletions

View File

@ -0,0 +1,662 @@
package idp
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
)
const (
// azure extension properties template
wtAccountIDTpl = "extension_%s_wt_account_id"
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
profileFields = "id,displayName,mail,userPrincipalName"
extensionFields = "id,name,targetObjects"
)
// AzureManager azure manager client instance.
type AzureManager struct {
ClientID string
ObjectID string
GraphAPIEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// AzureClientConfig azure manager client configurations.
type AzureClientConfig struct {
ClientID string
ClientSecret string
GraphAPIEndpoint string
ObjectID string
TokenEndpoint string
GrantType string
}
// AzureCredentials azure authentication information.
type AzureCredentials struct {
clientConfig AzureClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// azureProfile represents an azure user profile.
type azureProfile map[string]any
// passwordProfile represent authentication method for,
// newly created user profile.
type passwordProfile struct {
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
Password string `json:"password"`
}
// azureExtension represent custom attribute,
// that can be added to user objects in Azure Active Directory (AD).
type azureExtension struct {
Name string `json:"name"`
DataType string `json:"dataType"`
TargetObjects []string `json:"targetObjects"`
}
// NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.GraphAPIEndpoint == "" || config.TokenEndpoint == "" {
return nil, fmt.Errorf("azure idp configuration is not complete")
}
if config.GrantType != "client_credentials" {
return nil, fmt.Errorf("azure idp configuration failed. Grant Type should be client_credentials")
}
credentials := &AzureCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
manager := &AzureManager{
ObjectID: config.ObjectID,
ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}
err := manager.configureAppMetadata()
if err != nil {
return nil, err
}
return manager, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
func (ac *AzureCredentials) jwtStillValid() bool {
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
data.Set("grant_type", ac.clientConfig.GrantType)
data.Set("scope", "https://graph.microsoft.com/.default")
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = ac.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = ac.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if ac.jwtStillValid() {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken()
if err != nil {
return ac.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
if err != nil {
return ac.jwtToken, err
}
ac.jwtToken = jwtToken
return ac.jwtToken, nil
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID)
if err != nil {
return nil, err
}
body, err := am.post("users", payload)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
profile[wtAccountIDField] = accountID
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
profile[wtPendingInviteField] = true
return profile.userData(am.ClientID), nil
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
body, err := am.get("users/"+userID, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
return profile.userData(am.ClientID), nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
body, err := am.get("users/"+email, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, profile.userData(am.ClientID))
return users, nil
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
body, err := am.get("users", q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Value {
users = append(users, profile.userData(am.ClientID))
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
body, err := am.get("users", q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
var profiles struct{ Value []azureProfile }
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value {
userData := profile.userData(am.ClientID)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
data, err := am.helper.Marshal(map[string]any{
wtAccountIDField: appMetadata.WTAccountID,
wtPendingInviteField: appMetadata.WTPendingInvite,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating idp metadata for user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
return nil
}
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{}
q.Add("$select", extensionFields)
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.get(resource, q)
if err != nil {
return nil, err
}
var extensions struct{ Value []azureExtension }
err = am.helper.Unmarshal(body, &extensions)
if err != nil {
return nil, err
}
return extensions.Value, nil
}
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
extension := azureExtension{
Name: name,
DataType: "string",
TargetObjects: []string{"User"},
}
payload, err := am.helper.Marshal(extension)
if err != nil {
return nil, err
}
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.post(resource, string(payload))
if err != nil {
return nil, err
}
var userExtension azureExtension
err = am.helper.Unmarshal(body, &userExtension)
if err != nil {
return nil, err
}
return &userExtension, nil
}
// configureAppMetadata sets up app metadata extensions if they do not exists.
func (am *AzureManager) configureAppMetadata() error {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
extensions, err := am.getUserExtensions()
if err != nil {
return err
}
// If the wt_account_id extension does not already exist, create it.
if !hasExtension(extensions, wtAccountIDField) {
_, err = am.createUserExtension(wtAccountID)
if err != nil {
return err
}
}
// If the wt_pending_invite extension does not already exist, create it.
if !hasExtension(extensions, wtPendingInviteField) {
_, err = am.createUserExtension(wtPendingInvite)
if err != nil {
return err
}
}
return nil
}
// get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// post perform Post requests.
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile.
func (ap azureProfile) userData(clientID string) *UserData {
id, ok := ap["id"].(string)
if !ok {
id = ""
}
email, ok := ap["userPrincipalName"].(string)
if !ok {
email = ""
}
name, ok := ap["displayName"].(string)
if !ok {
name = ""
}
accountIDField := extensionName(wtAccountIDTpl, clientID)
accountID, ok := ap[accountIDField].(string)
if !ok {
accountID = ""
}
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
pendingInvite, ok := ap[pendingInviteField].(bool)
if !ok {
pendingInvite = false
}
return &UserData{
Email: email,
Name: name,
ID: id,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
}
}
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
req := &azureProfile{
"accountEnabled": true,
"displayName": name,
"mailNickName": strings.Join(strings.Split(name, " "), ""),
"userPrincipalName": email,
"passwordProfile": passwordProfile{
ForceChangePasswordNextSignIn: true,
Password: GeneratePassword(8, 1, 1, 1),
},
wtAccountIDField: accountID,
wtPendingInviteField: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func extensionName(extensionTpl, clientID string) string {
clientID = strings.ReplaceAll(clientID, "-", "")
return fmt.Sprintf(extensionTpl, clientID)
}
// hasExtension checks whether a given extension by name,
// exists in an list of extensions.
func hasExtension(extensions []azureExtension, name string) bool {
for _, ext := range extensions {
if ext.Name == name {
return true
}
}
return false
}

View File

@ -0,0 +1,329 @@
package idp
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type mockAzureCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestAzureJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := AzureClientConfig{}
creds := AzureCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestAzureAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := AzureClientConfig{}
creds := AzureCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestAzureUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &AzureManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestAzureProfile(t *testing.T) {
type azureProfileTest struct {
name string
clientID string
invite bool
inputProfile azureProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "Good Request",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test1",
"displayName": "John Doe",
"userPrincipalName": "test1@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
},
expectedUserData: UserData{
Email: "test1@test.com",
Name: "John Doe",
ID: "test1",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Missing User ID",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: true,
inputProfile: azureProfile{
"displayName": "John Doe",
"userPrincipalName": "test2@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
},
expectedUserData: UserData{
Email: "test2@test.com",
Name: "John Doe",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase3 := azureProfileTest{
name: "Missing User Name",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test3",
"userPrincipalName": "test3@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
},
expectedUserData: UserData{
ID: "test3",
Email: "test3@test.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase4 := azureProfileTest{
name: "Missing Extension Fields",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test4",
"displayName": "John Doe",
"userPrincipalName": "test4@test.com",
},
expectedUserData: UserData{
ID: "test4",
Name: "John Doe",
Email: "test4@test.com",
AppMetadata: AppMetadata{},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData(testCase.clientID)
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
})
}
}

View File

@ -24,6 +24,7 @@ type Config struct {
ManagerType string
Auth0ClientCredentials Auth0ClientConfig
KeycloakClientCredentials KeycloakClientConfig
AzureClientCredentials AzureClientConfig
}
// ManagerCredentials interface that authenticates using the credential of each type of idp
@ -73,6 +74,8 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
return nil, nil
case "auth0":
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
case "azure":
return NewAzureManager(config.AzureClientCredentials, appMetrics)
case "keycloak":
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
default: