netbird/management/server/idp/azure.go
pascal-fischer 765aba2c1c
Add context to throughout the project and update logging (#2209)
propagate context from all the API calls and log request ID, account ID and peer ID

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-07-03 11:33:02 +02:00

460 lines
12 KiB
Go

package idp
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const profileFields = "id,displayName,mail,userPrincipalName"
// AzureManager azure manager client instance.
type AzureManager struct {
ClientID string
ObjectID string
GraphAPIEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// AzureClientConfig azure manager client configurations.
type AzureClientConfig struct {
ClientID string
ClientSecret string
ObjectID string
GraphAPIEndpoint string
TokenEndpoint string
GrantType string
}
// AzureCredentials azure authentication information.
type AzureCredentials struct {
clientConfig AzureClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// azureProfile represents an azure user profile.
type azureProfile map[string]any
// NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
}
if config.TokenEndpoint == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
}
if config.GraphAPIEndpoint == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
}
if config.ObjectID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
}
credentials := &AzureCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &AzureManager{
ObjectID: config.ObjectID,
ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
func (ac *AzureCredentials) jwtStillValid() bool {
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
data.Set("grant_type", ac.clientConfig.GrantType)
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
if err != nil {
return nil, err
}
// get base url and add "/.default" as scope
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
scopeURL := baseURL + "/.default"
data.Set("scope", scopeURL)
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = ac.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = ac.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if ac.jwtStillValid() {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken(ctx)
if err != nil {
return ac.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
if err != nil {
return ac.jwtToken, err
}
ac.jwtToken = jwtToken
return ac.jwtToken, nil
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get(ctx, "users/"+userID, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
userData := profile.userData()
userData.AppMetadata = appMetadata
return userData, nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get(ctx, "users/"+email, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, profile.userData())
return users, nil
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Azure.
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0)
q := url.Values{}
q.Add("$select", profileFields)
q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; {
body, err := am.get(ctx, nextLink, q)
if err != nil {
return nil, err
}
var profiles struct {
Value []azureProfile
NextLink string `json:"@odata.nextLink"`
}
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
for _, profile := range profiles.Value {
users = append(users, profile.userData())
}
nextLink = profiles.NextLink
}
return users, nil
}
// get perform Get requests.
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
var reqURL string
if strings.HasPrefix(resource, "https") {
// Already an absolute URL for paging
reqURL = resource
} else {
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
}
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile.
func (ap azureProfile) userData() *UserData {
id, ok := ap["id"].(string)
if !ok {
id = ""
}
email, ok := ap["userPrincipalName"].(string)
if !ok {
email = ""
}
name, ok := ap["displayName"].(string)
if !ok {
name = ""
}
return &UserData{
Email: email,
Name: name,
ID: id,
}
}