mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-20 21:08:45 +01:00
3def84b111
Support Generic OAuth 2.0 Device Authorization Grant as per RFC specification https://www.rfc-editor.org/rfc/rfc8628. The previous version supported only Auth0 as an IDP backend. This implementation enables the Interactive SSO Login feature for any IDP compatible with the specification, e.g., Keycloak.
279 lines
7.8 KiB
Go
279 lines
7.8 KiB
Go
package internal
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// OAuthClient is a OAuth client interface for various idp providers
|
|
type OAuthClient interface {
|
|
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
|
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
|
GetClientID(ctx context.Context) string
|
|
}
|
|
|
|
// HTTPClient is the minimal HTTP surface needed for API calls. It is
// satisfied by *http.Client and can be swapped for a stub in tests.
type HTTPClient interface {
	// Do sends the request and returns the response or an error.
	Do(req *http.Request) (*http.Response, error)
}
|
|
|
|
// DeviceAuthInfo carries the device authorization response used to drive
// the OAuth device login flow.
type DeviceAuthInfo struct {
	// DeviceCode is sent to the token endpoint while polling for the token.
	DeviceCode string `json:"device_code"`
	// UserCode is the value of the "user_code" response field.
	UserCode string `json:"user_code"`
	// VerificationURI is the value of the "verification_uri" response field.
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete is the value of the
	// "verification_uri_complete" response field.
	VerificationURIComplete string `json:"verification_uri_complete"`
	// ExpiresIn is the value of the "expires_in" response field.
	ExpiresIn int `json:"expires_in"`
	// Interval is the polling interval in seconds.
	Interval int `json:"interval"`
}
|
|
|
|
// TokenInfo describes an issued access token together with the companion
// values returned alongside it.
type TokenInfo struct {
	// AccessToken is the value of the "access_token" response field.
	AccessToken string `json:"access_token"`
	// RefreshToken is the value of the "refresh_token" response field.
	RefreshToken string `json:"refresh_token"`
	// IDToken is the value of the "id_token" response field.
	IDToken string `json:"id_token"`
	// TokenType is the value of the "token_type" response field.
	TokenType string `json:"token_type"`
	// ExpiresIn is the value of the "expires_in" response field.
	ExpiresIn int `json:"expires_in"`
}
|
|
|
|
// Grant type values used by the Hosted client.
const (
	// HostedGrantType is the grant type for the device flow on Hosted.
	HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
	// HostedRefreshGrant is the refresh token grant type value.
	HostedRefreshGrant = "refresh_token"
)
|
|
|
|
// Hosted client
|
|
type Hosted struct {
|
|
// Hosted API Audience for validation
|
|
Audience string
|
|
// Hosted Native application client id
|
|
ClientID string
|
|
// TokenEndpoint to request access token
|
|
TokenEndpoint string
|
|
// DeviceAuthEndpoint to request device authorization code
|
|
DeviceAuthEndpoint string
|
|
|
|
HTTPClient HTTPClient
|
|
}
|
|
|
|
// RequestDeviceCodePayload is the JSON payload of a device code request
// for auth0.
type RequestDeviceCodePayload struct {
	// Audience is serialized as "audience".
	Audience string `json:"audience"`
	// ClientID is serialized as "client_id".
	ClientID string `json:"client_id"`
}
|
|
|
|
// TokenRequestPayload is the JSON payload of an auth0 token request.
type TokenRequestPayload struct {
	// GrantType is serialized as "grant_type".
	GrantType string `json:"grant_type"`
	// DeviceCode is serialized as "device_code" and omitted when empty.
	DeviceCode string `json:"device_code,omitempty"`
	// ClientID is serialized as "client_id".
	ClientID string `json:"client_id"`
	// RefreshToken is serialized as "refresh_token" and omitted when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
}
|
|
|
|
// TokenRequestResponse used for parsing Hosted token's response
|
|
type TokenRequestResponse struct {
|
|
Error string `json:"error"`
|
|
ErrorDescription string `json:"error_description"`
|
|
TokenInfo
|
|
}
|
|
|
|
// Claims holds the token claims inspected when validating the access token.
type Claims struct {
	// Audience is the "aud" claim. It is kept untyped because the claim can
	// be either a single string or an array of strings.
	Audience interface{} `json:"aud"`
}
|
|
|
|
// NewHostedDeviceFlow returns an Hosted OAuth client
|
|
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
|
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
httpTransport.MaxIdleConns = 5
|
|
|
|
httpClient := &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
Transport: httpTransport,
|
|
}
|
|
|
|
return &Hosted{
|
|
Audience: audience,
|
|
ClientID: clientID,
|
|
TokenEndpoint: tokenEndpoint,
|
|
HTTPClient: httpClient,
|
|
DeviceAuthEndpoint: deviceAuthEndpoint,
|
|
}
|
|
}
|
|
|
|
// GetClientID returns the provider client id
|
|
func (h *Hosted) GetClientID(ctx context.Context) string {
|
|
return h.ClientID
|
|
}
|
|
|
|
// RequestDeviceCode requests a device code login flow information from Hosted
|
|
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
|
form := url.Values{}
|
|
form.Add("client_id", h.ClientID)
|
|
form.Add("audience", h.Audience)
|
|
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
|
|
strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
|
}
|
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
res, err := h.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err)
|
|
}
|
|
|
|
defer res.Body.Close()
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
|
}
|
|
|
|
if res.StatusCode != 200 {
|
|
return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body))
|
|
}
|
|
|
|
deviceCode := DeviceAuthInfo{}
|
|
err = json.Unmarshal(body, &deviceCode)
|
|
if err != nil {
|
|
return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err)
|
|
}
|
|
|
|
return deviceCode, err
|
|
}
|
|
|
|
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
|
form := url.Values{}
|
|
form.Add("client_id", h.ClientID)
|
|
form.Add("grant_type", HostedGrantType)
|
|
form.Add("device_code", info.DeviceCode)
|
|
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
|
}
|
|
|
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
res, err := h.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
|
}
|
|
|
|
defer func() {
|
|
err := res.Body.Close()
|
|
if err != nil {
|
|
return
|
|
}
|
|
}()
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
|
}
|
|
|
|
if res.StatusCode > 499 {
|
|
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
|
}
|
|
|
|
tokenResponse := TokenRequestResponse{}
|
|
err = json.Unmarshal(body, &tokenResponse)
|
|
if err != nil {
|
|
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
}
|
|
|
|
return tokenResponse, nil
|
|
}
|
|
|
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
|
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
|
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
|
interval := time.Duration(info.Interval) * time.Second
|
|
ticker := time.NewTicker(interval)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return TokenInfo{}, ctx.Err()
|
|
case <-ticker.C:
|
|
|
|
tokenResponse, err := h.requestToken(info)
|
|
if err != nil {
|
|
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
|
}
|
|
|
|
if tokenResponse.Error != "" {
|
|
if tokenResponse.Error == "authorization_pending" {
|
|
continue
|
|
} else if tokenResponse.Error == "slow_down" {
|
|
interval = interval + (3 * time.Second)
|
|
ticker.Reset(interval)
|
|
continue
|
|
}
|
|
|
|
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
|
}
|
|
|
|
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
|
|
if err != nil {
|
|
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
|
}
|
|
|
|
tokenInfo := TokenInfo{
|
|
AccessToken: tokenResponse.AccessToken,
|
|
TokenType: tokenResponse.TokenType,
|
|
RefreshToken: tokenResponse.RefreshToken,
|
|
IDToken: tokenResponse.IDToken,
|
|
ExpiresIn: tokenResponse.ExpiresIn,
|
|
}
|
|
return tokenInfo, err
|
|
}
|
|
}
|
|
}
|
|
|
|
// isValidAccessToken is a simple validation of the access token
|
|
func isValidAccessToken(token string, audience string) error {
|
|
if token == "" {
|
|
return fmt.Errorf("token received is empty")
|
|
}
|
|
|
|
encodedClaims := strings.Split(token, ".")[1]
|
|
claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
claims := Claims{}
|
|
err = json.Unmarshal(claimsString, &claims)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if claims.Audience == nil {
|
|
return fmt.Errorf("required token field audience is absent")
|
|
}
|
|
|
|
// Audience claim of JWT can be a string or an array of strings
|
|
typ := reflect.TypeOf(claims.Audience)
|
|
switch typ.Kind() {
|
|
case reflect.String:
|
|
if claims.Audience == audience {
|
|
return nil
|
|
}
|
|
case reflect.Slice:
|
|
for _, aud := range claims.Audience.([]interface{}) {
|
|
if audience == aud {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("invalid JWT token audience field")
|
|
}
|