package internal import ( "context" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "reflect" "strings" "time" ) // OAuthClient is a OAuth client interface for various idp providers type OAuthClient interface { RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) GetClientID(ctx context.Context) string } // HTTPClient http client interface for API calls type HTTPClient interface { Do(req *http.Request) (*http.Response, error) } // DeviceAuthInfo holds information for the OAuth device login flow type DeviceAuthInfo struct { DeviceCode string `json:"device_code"` UserCode string `json:"user_code"` VerificationURI string `json:"verification_uri"` VerificationURIComplete string `json:"verification_uri_complete"` ExpiresIn int `json:"expires_in"` Interval int `json:"interval"` } // TokenInfo holds information of issued access token type TokenInfo struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` IDToken string `json:"id_token"` TokenType string `json:"token_type"` ExpiresIn int `json:"expires_in"` } // HostedGrantType grant type for device flow on Hosted const ( HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code" HostedRefreshGrant = "refresh_token" ) // Hosted client type Hosted struct { // Hosted API Audience for validation Audience string // Hosted Native application client id ClientID string // Hosted Native application request scope Scope string // TokenEndpoint to request access token TokenEndpoint string // DeviceAuthEndpoint to request device authorization code DeviceAuthEndpoint string HTTPClient HTTPClient } // RequestDeviceCodePayload used for request device code payload for auth0 type RequestDeviceCodePayload struct { Audience string `json:"audience"` ClientID string `json:"client_id"` Scope string `json:"scope"` } // TokenRequestPayload used for requesting the auth0 token type TokenRequestPayload struct { GrantType string `json:"grant_type"` DeviceCode string `json:"device_code,omitempty"` ClientID string `json:"client_id"` RefreshToken string `json:"refresh_token,omitempty"` } // TokenRequestResponse used for parsing Hosted token's response type TokenRequestResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` TokenInfo } // Claims used when validating the access token type Claims struct { Audience interface{} `json:"aud"` } // NewHostedDeviceFlow returns an Hosted OAuth client func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 httpClient := &http.Client{ Timeout: 10 * time.Second, Transport: httpTransport, } return &Hosted{ Audience: audience, ClientID: clientID, Scope: "openid", TokenEndpoint: tokenEndpoint, HTTPClient: httpClient, DeviceAuthEndpoint: deviceAuthEndpoint, } } // GetClientID returns the provider client id func (h *Hosted) GetClientID(ctx context.Context) string { return h.ClientID } // RequestDeviceCode requests a device code login flow information from Hosted func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) { form := url.Values{} form.Add("client_id", h.ClientID) form.Add("audience", h.Audience) form.Add("scope", h.Scope) req, err := http.NewRequest("POST", h.DeviceAuthEndpoint, strings.NewReader(form.Encode())) if err != nil { return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err) } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") res, err := h.HTTPClient.Do(req) if err != nil { return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err) } defer res.Body.Close() body, err := io.ReadAll(res.Body) if err != nil { return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err) } if res.StatusCode != 200 { return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body)) } deviceCode := DeviceAuthInfo{} err = json.Unmarshal(body, &deviceCode) if err != nil { return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err) } return deviceCode, err } func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) { form := url.Values{} form.Add("client_id", h.ClientID) form.Add("grant_type", HostedGrantType) form.Add("device_code", info.DeviceCode) req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err) } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") res, err := h.HTTPClient.Do(req) if err != nil { return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err) } defer func() { err := res.Body.Close() if err != nil { return } }() body, err := io.ReadAll(res.Body) if err != nil { return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err) } if res.StatusCode > 499 { return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body)) } tokenResponse := TokenRequestResponse{} err = json.Unmarshal(body, &tokenResponse) if err != nil { return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err) } return tokenResponse, nil } // WaitToken waits user's login and authorize the app. Once the user's authorize // it retrieves the access token from Hosted's endpoint and validates it before returning func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) { interval := time.Duration(info.Interval) * time.Second ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): return TokenInfo{}, ctx.Err() case <-ticker.C: tokenResponse, err := h.requestToken(info) if err != nil { return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err) } if tokenResponse.Error != "" { if tokenResponse.Error == "authorization_pending" { continue } else if tokenResponse.Error == "slow_down" { interval = interval + (3 * time.Second) ticker.Reset(interval) continue } return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) } err = isValidAccessToken(tokenResponse.AccessToken, h.Audience) if err != nil { return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) } tokenInfo := TokenInfo{ AccessToken: tokenResponse.AccessToken, TokenType: tokenResponse.TokenType, RefreshToken: tokenResponse.RefreshToken, IDToken: tokenResponse.IDToken, ExpiresIn: tokenResponse.ExpiresIn, } return tokenInfo, err } } } // isValidAccessToken is a simple validation of the access token func isValidAccessToken(token string, audience string) error { if token == "" { return fmt.Errorf("token received is empty") } encodedClaims := strings.Split(token, ".")[1] claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims) if err != nil { return err } claims := Claims{} err = json.Unmarshal(claimsString, &claims) if err != nil { return err } if claims.Audience == nil { return fmt.Errorf("required token field audience is absent") } // Audience claim of JWT can be a string or an array of strings typ := reflect.TypeOf(claims.Audience) switch typ.Kind() { case reflect.String: if claims.Audience == audience { return nil } case reflect.Slice: for _, aud := range claims.Audience.([]interface{}) { if audience == aud { return nil } } } return fmt.Errorf("invalid JWT token audience field") }