mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 22:10:56 +01:00
459 lines
11 KiB
Go
459 lines
11 KiB
Go
package idp
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
)
|
|
|
|
// profileFields is the comma-separated list of user properties sent as the
// "$select" query parameter on every user request made by this manager.
const profileFields = "id,displayName,mail,userPrincipalName"
|
|
|
|
// AzureManager is the IdP manager client for Azure. It issues user requests
// against GraphAPIEndpoint, authenticating each one with a bearer token
// obtained from credentials.
type AzureManager struct {
	// ClientID is copied from AzureClientConfig.ClientID by NewAzureManager.
	ClientID string
	// ObjectID is copied from AzureClientConfig.ObjectID by NewAzureManager.
	ObjectID string
	// GraphAPIEndpoint is the base URL that resource paths are appended to.
	GraphAPIEndpoint string
	// httpClient executes the outgoing HTTP requests.
	httpClient ManagerHTTPClient
	// credentials supplies the access token used in the authorization header.
	credentials ManagerCredentials
	// helper decodes response bodies.
	helper ManagerHelper
	// appMetrics records IdP counters; it may be nil, in which case nothing is counted.
	appMetrics telemetry.AppMetrics
}
|
|
|
|
// AzureClientConfig holds the settings needed to build an AzureManager.
// NewAzureManager rejects the configuration if any field is empty.
type AzureClientConfig struct {
	// ClientID is sent as "client_id" in the token request.
	ClientID string
	// ClientSecret is sent as "client_secret" in the token request.
	ClientSecret string
	// ObjectID is required and copied onto the manager.
	ObjectID string
	// GraphAPIEndpoint is the base URL for user requests; its scheme and host
	// also form the token scope ("<scheme>://<host>/.default").
	GraphAPIEndpoint string
	// TokenEndpoint is the URL the token request is POSTed to.
	TokenEndpoint string
	// GrantType is sent as "grant_type" in the token request.
	GrantType string
}
|
|
|
|
// AzureCredentials obtains and caches the access token used by AzureManager.
type AzureCredentials struct {
	// clientConfig provides the client ID, secret, grant type and endpoints
	// used to build the token request.
	clientConfig AzureClientConfig
	// helper decodes the token response and the token's payload segment.
	helper ManagerHelper
	// httpClient executes the token request.
	httpClient ManagerHTTPClient
	// jwtToken is the most recently obtained token; Authenticate reuses it
	// while jwtStillValid reports true.
	jwtToken JWTToken
	// mux is held for the whole of Authenticate, serializing token refreshes.
	mux sync.Mutex
	// appMetrics records IdP counters; it may be nil, in which case nothing is counted.
	appMetrics telemetry.AppMetrics
}
|
|
|
|
// azureProfile is a single user object as decoded from a response body,
// keyed by property name. The "id", "userPrincipalName" and "displayName"
// keys are the ones read when converting it to UserData.
type azureProfile map[string]any
|
|
|
|
// NewAzureManager creates a new instance of the AzureManager.
|
|
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
httpTransport.MaxIdleConns = 5
|
|
|
|
httpClient := &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
Transport: httpTransport,
|
|
}
|
|
helper := JsonParser{}
|
|
|
|
if config.ClientID == "" {
|
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
|
}
|
|
|
|
if config.ClientSecret == "" {
|
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
|
}
|
|
|
|
if config.TokenEndpoint == "" {
|
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
|
|
}
|
|
|
|
if config.GraphAPIEndpoint == "" {
|
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
|
|
}
|
|
|
|
if config.ObjectID == "" {
|
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
|
}
|
|
|
|
if config.GrantType == "" {
|
|
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
|
|
}
|
|
|
|
credentials := &AzureCredentials{
|
|
clientConfig: config,
|
|
httpClient: httpClient,
|
|
helper: helper,
|
|
appMetrics: appMetrics,
|
|
}
|
|
|
|
return &AzureManager{
|
|
ObjectID: config.ObjectID,
|
|
ClientID: config.ClientID,
|
|
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
|
httpClient: httpClient,
|
|
credentials: credentials,
|
|
helper: helper,
|
|
appMetrics: appMetrics,
|
|
}, nil
|
|
}
|
|
|
|
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
|
func (ac *AzureCredentials) jwtStillValid() bool {
|
|
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
|
}
|
|
|
|
// requestJWTToken performs request to get jwt token.
|
|
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
|
data := url.Values{}
|
|
data.Set("client_id", ac.clientConfig.ClientID)
|
|
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
|
data.Set("grant_type", ac.clientConfig.GrantType)
|
|
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// get base url and add "/.default" as scope
|
|
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
|
scopeURL := baseURL + "/.default"
|
|
data.Set("scope", scopeURL)
|
|
|
|
payload := strings.NewReader(data.Encode())
|
|
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
|
|
|
log.Debug("requesting new jwt token for azure idp manager")
|
|
|
|
resp, err := ac.httpClient.Do(req)
|
|
if err != nil {
|
|
if ac.appMetrics != nil {
|
|
ac.appMetrics.IDPMetrics().CountRequestError()
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
|
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
|
jwtToken := JWTToken{}
|
|
body, err := io.ReadAll(rawBody)
|
|
if err != nil {
|
|
return jwtToken, err
|
|
}
|
|
|
|
err = ac.helper.Unmarshal(body, &jwtToken)
|
|
if err != nil {
|
|
return jwtToken, err
|
|
}
|
|
|
|
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
|
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
|
}
|
|
|
|
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
|
if err != nil {
|
|
return jwtToken, err
|
|
}
|
|
|
|
// Exp maps into exp from jwt token
|
|
var IssuedAt struct{ Exp int64 }
|
|
err = ac.helper.Unmarshal(data, &IssuedAt)
|
|
if err != nil {
|
|
return jwtToken, err
|
|
}
|
|
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
|
|
|
return jwtToken, nil
|
|
}
|
|
|
|
// Authenticate retrieves access token to use the azure Management API.
|
|
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
|
ac.mux.Lock()
|
|
defer ac.mux.Unlock()
|
|
|
|
if ac.appMetrics != nil {
|
|
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
|
}
|
|
|
|
// reuse the token without requesting a new one if it is not expired,
|
|
// and if expiry time is sufficient time available to make a request.
|
|
if ac.jwtStillValid() {
|
|
return ac.jwtToken, nil
|
|
}
|
|
|
|
resp, err := ac.requestJWTToken()
|
|
if err != nil {
|
|
return ac.jwtToken, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
|
if err != nil {
|
|
return ac.jwtToken, err
|
|
}
|
|
|
|
ac.jwtToken = jwtToken
|
|
|
|
return ac.jwtToken, nil
|
|
}
|
|
|
|
// CreateUser creates a new user in azure AD Idp.
|
|
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
|
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
}
|
|
|
|
// GetUserDataByID requests user data from keycloak via ID.
|
|
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
q := url.Values{}
|
|
q.Add("$select", profileFields)
|
|
|
|
body, err := am.get("users/"+userID, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
}
|
|
|
|
var profile azureProfile
|
|
err = am.helper.Unmarshal(body, &profile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userData := profile.userData()
|
|
userData.AppMetadata = appMetadata
|
|
|
|
return userData, nil
|
|
}
|
|
|
|
// GetUserByEmail searches users with a given email.
|
|
// If no users have been found, this function returns an empty list.
|
|
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
|
q := url.Values{}
|
|
q.Add("$select", profileFields)
|
|
|
|
body, err := am.get("users/"+email, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
}
|
|
|
|
var profile azureProfile
|
|
err = am.helper.Unmarshal(body, &profile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
users := make([]*UserData, 0)
|
|
users = append(users, profile.userData())
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// GetAccount returns all the users for a given profile.
|
|
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
|
users, err := am.getAllUsers()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountGetAccount()
|
|
}
|
|
|
|
for index, user := range users {
|
|
user.AppMetadata.WTAccountID = accountID
|
|
users[index] = user
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
// It returns a list of users indexed by accountID.
|
|
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
|
users, err := am.getAllUsers()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
indexedUsers := make(map[string][]*UserData)
|
|
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
|
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
}
|
|
|
|
return indexedUsers, nil
|
|
}
|
|
|
|
// UpdateUserAppMetadata updates user app metadata based on userID.
|
|
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
|
return nil
|
|
}
|
|
|
|
// InviteUserByID resend invitations to users who haven't activated,
|
|
// their accounts prior to the expiration period.
|
|
func (am *AzureManager) InviteUserByID(_ string) error {
|
|
return fmt.Errorf("method InviteUserByID not implemented")
|
|
}
|
|
|
|
// DeleteUser from Azure.
|
|
func (am *AzureManager) DeleteUser(userID string) error {
|
|
jwtToken, err := am.credentials.Authenticate()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
|
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
req.Header.Add("content-type", "application/json")
|
|
|
|
log.Debugf("delete idp user %s", userID)
|
|
|
|
resp, err := am.httpClient.Do(req)
|
|
if err != nil {
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountRequestError()
|
|
}
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountDeleteUser()
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getAllUsers returns all users in an Azure AD account.
|
|
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
|
users := make([]*UserData, 0)
|
|
|
|
q := url.Values{}
|
|
q.Add("$select", profileFields)
|
|
q.Add("$top", "500")
|
|
|
|
for nextLink := "users"; nextLink != ""; {
|
|
body, err := am.get(nextLink, q)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var profiles struct {
|
|
Value []azureProfile
|
|
NextLink string `json:"@odata.nextLink"`
|
|
}
|
|
err = am.helper.Unmarshal(body, &profiles)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, profile := range profiles.Value {
|
|
users = append(users, profile.userData())
|
|
}
|
|
|
|
nextLink = profiles.NextLink
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// get perform Get requests.
|
|
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
|
jwtToken, err := am.credentials.Authenticate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var reqURL string
|
|
if strings.HasPrefix(resource, "https") {
|
|
// Already an absolute URL for paging
|
|
reqURL = resource
|
|
} else {
|
|
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
req.Header.Add("content-type", "application/json")
|
|
|
|
resp, err := am.httpClient.Do(req)
|
|
if err != nil {
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountRequestError()
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
if am.appMetrics != nil {
|
|
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
}
|
|
|
|
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
|
}
|
|
|
|
return io.ReadAll(resp.Body)
|
|
}
|
|
|
|
// userData construct user data from keycloak profile.
|
|
func (ap azureProfile) userData() *UserData {
|
|
id, ok := ap["id"].(string)
|
|
if !ok {
|
|
id = ""
|
|
}
|
|
|
|
email, ok := ap["userPrincipalName"].(string)
|
|
if !ok {
|
|
email = ""
|
|
}
|
|
|
|
name, ok := ap["displayName"].(string)
|
|
if !ok {
|
|
name = ""
|
|
}
|
|
|
|
return &UserData{
|
|
Email: email,
|
|
Name: name,
|
|
ID: id,
|
|
}
|
|
}
|