2022-03-20 08:29:18 +01:00
|
|
|
package oauth
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"io/ioutil"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// mockHTTPClient is a test double for the HTTP client used by Auth0. It
// records the body of the most recent request and answers every call with a
// canned response.
//
// The first MaxReqs calls to Do respond with countResBody; every later call
// responds with resBody. All responses share the same status code and error.
type mockHTTPClient struct {
	// code is the HTTP status code of every response.
	code int
	// resBody is served once the first MaxReqs requests have been answered.
	resBody string
	// reqBody is the body read from the most recent request.
	reqBody string
	// MaxReqs is how many initial requests are answered with countResBody.
	MaxReqs int
	// count is how many requests have been answered with countResBody so far.
	count int
	// countResBody is served for the first MaxReqs requests.
	countResBody string
	// err is returned alongside every response, simulating a transport failure.
	err error
}

// Do records the request body in c.reqBody and returns a canned response.
//
// Like http.Client.Do, it closes the request body. A request without a body
// (req.Body == nil) is recorded as an empty body; passing a nil body to
// ReadAll would panic.
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	if req.Body == nil {
		c.reqBody = ""
	} else {
		body, err := ioutil.ReadAll(req.Body)
		_ = req.Body.Close() // a close error carries no information for the mock
		if err == nil {
			c.reqBody = string(body)
		}
	}

	// The first MaxReqs calls get countResBody (e.g. an "authorization_pending"
	// answer); every later call gets resBody.
	resBody := c.resBody
	if c.count < c.MaxReqs {
		c.count++
		resBody = c.countResBody
	}

	// NOTE(review): a real http.Client.Do normally returns a nil *Response when
	// err is non-nil; this mock returns both, as it always has.
	return &http.Response{
		StatusCode: c.code,
		Body:       ioutil.NopCloser(strings.NewReader(resBody)),
	}, c.err
}
|
|
|
|
|
|
|
|
func TestAuth0_RequestDeviceCode(t *testing.T) {
|
|
|
|
type test struct {
|
|
|
|
name string
|
|
|
|
inputResBody string
|
|
|
|
inputReqCode int
|
|
|
|
inputReqError error
|
|
|
|
testingErrFunc require.ErrorAssertionFunc
|
|
|
|
expectedErrorMSG string
|
|
|
|
testingFunc require.ComparisonAssertionFunc
|
|
|
|
expectedOut DeviceAuthInfo
|
|
|
|
expectedMSG string
|
|
|
|
expectPayload RequestDeviceCodePayload
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase1 := test{
|
|
|
|
name: "Payload Is Valid",
|
|
|
|
expectPayload: RequestDeviceCodePayload{
|
|
|
|
Audience: "ok",
|
|
|
|
ClientID: "bla",
|
|
|
|
},
|
|
|
|
inputReqCode: 200,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase2 := test{
|
|
|
|
name: "Exit On Network Error",
|
|
|
|
inputReqError: fmt.Errorf("error"),
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase3 := test{
|
|
|
|
name: "Exit On Exit Code",
|
|
|
|
inputReqCode: 400,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
|
|
|
testCase4 := test{
|
|
|
|
name: "Got Device Code",
|
|
|
|
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
|
|
|
expectPayload: RequestDeviceCodePayload{
|
|
|
|
Audience: "ok",
|
|
|
|
ClientID: "bla",
|
|
|
|
},
|
|
|
|
inputReqCode: 200,
|
|
|
|
testingErrFunc: require.NoError,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectedOut: testCase4Out,
|
|
|
|
expectedMSG: "out should match",
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
|
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
|
|
|
|
httpClient := mockHTTPClient{
|
|
|
|
resBody: testCase.inputResBody,
|
|
|
|
code: testCase.inputReqCode,
|
|
|
|
err: testCase.inputReqError,
|
|
|
|
}
|
|
|
|
|
|
|
|
auth0 := Auth0{
|
|
|
|
Audience: testCase.expectPayload.Audience,
|
|
|
|
ClientID: testCase.expectPayload.ClientID,
|
|
|
|
Domain: "test.auth0.com",
|
|
|
|
HTTPClient: &httpClient,
|
|
|
|
}
|
|
|
|
|
|
|
|
authInfo, err := auth0.RequestDeviceCode(context.TODO())
|
|
|
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
|
|
|
|
|
|
|
payload, _ := json.Marshal(testCase.expectPayload)
|
|
|
|
|
|
|
|
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
|
|
|
|
|
|
|
testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG)
|
|
|
|
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuth0_WaitToken(t *testing.T) {
|
|
|
|
type test struct {
|
|
|
|
name string
|
|
|
|
inputResBody string
|
|
|
|
inputReqCode int
|
|
|
|
inputReqError error
|
|
|
|
inputMaxReqs int
|
|
|
|
inputCountResBody string
|
|
|
|
inputTimeout time.Duration
|
|
|
|
inputInfo DeviceAuthInfo
|
|
|
|
inputAudience string
|
|
|
|
testingErrFunc require.ErrorAssertionFunc
|
|
|
|
expectedErrorMSG string
|
|
|
|
testingFunc require.ComparisonAssertionFunc
|
|
|
|
expectedOut TokenInfo
|
|
|
|
expectedMSG string
|
|
|
|
expectPayload TokenRequestPayload
|
|
|
|
}
|
|
|
|
|
|
|
|
defaultInfo := DeviceAuthInfo{
|
|
|
|
DeviceCode: "test",
|
|
|
|
ExpiresIn: 10,
|
|
|
|
Interval: 1,
|
|
|
|
}
|
|
|
|
|
|
|
|
tokenReqPayload := TokenRequestPayload{
|
|
|
|
GrantType: auth0GrantType,
|
|
|
|
DeviceCode: defaultInfo.DeviceCode,
|
|
|
|
ClientID: "test",
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase1 := test{
|
|
|
|
name: "Payload Is Valid",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
inputReqCode: 200,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase2 := test{
|
|
|
|
name: "Exit On Network Error",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
inputReqError: fmt.Errorf("error"),
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase3 := test{
|
|
|
|
name: "Exit On 4XX When Not Pending",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
inputReqCode: 400,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase4 := test{
|
|
|
|
name: "Exit On Exit Code 5XX",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
inputReqCode: 500,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase5 := test{
|
|
|
|
name: "Exit On Content Timeout",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputTimeout: 0 * time.Second,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
audience := "test"
|
|
|
|
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
|
|
|
|
var hmacSampleSecret []byte
|
|
|
|
tokenString, _ := token.SignedString(hmacSampleSecret)
|
|
|
|
|
|
|
|
testCase6 := test{
|
|
|
|
name: "Exit On Invalid Audience",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
inputReqCode: 200,
|
|
|
|
inputAudience: "super test",
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase7 := test{
|
|
|
|
name: "Received Token Info",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
inputReqCode: 200,
|
|
|
|
inputAudience: audience,
|
|
|
|
testingErrFunc: require.NoError,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
expectedOut: TokenInfo{AccessToken: tokenString},
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase8 := test{
|
|
|
|
name: "Received Token Info after Multiple tries",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
|
|
inputMaxReqs: 2,
|
|
|
|
inputCountResBody: "{\"error\":\"authorization_pending\"}",
|
|
|
|
inputReqCode: 200,
|
|
|
|
inputAudience: audience,
|
|
|
|
testingErrFunc: require.NoError,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
expectedOut: TokenInfo{AccessToken: tokenString},
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7, testCase8} {
|
|
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
|
|
|
|
httpClient := mockHTTPClient{
|
|
|
|
resBody: testCase.inputResBody,
|
|
|
|
code: testCase.inputReqCode,
|
|
|
|
err: testCase.inputReqError,
|
|
|
|
MaxReqs: testCase.inputMaxReqs,
|
|
|
|
countResBody: testCase.inputCountResBody,
|
|
|
|
}
|
|
|
|
|
|
|
|
auth0 := Auth0{
|
|
|
|
Audience: testCase.inputAudience,
|
|
|
|
ClientID: testCase.expectPayload.ClientID,
|
|
|
|
Domain: "test.auth0.com",
|
|
|
|
HTTPClient: &httpClient,
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
|
|
|
defer cancel()
|
|
|
|
tokenInfo, err := auth0.WaitToken(ctx, testCase.inputInfo)
|
|
|
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
|
|
|
|
|
|
|
var payload []byte
|
|
|
|
var emptyPayload TokenRequestPayload
|
|
|
|
if testCase.expectPayload != emptyPayload {
|
|
|
|
payload, _ = json.Marshal(testCase.expectPayload)
|
|
|
|
}
|
|
|
|
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
|
|
|
|
|
|
|
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
|
|
|
|
|
|
|
require.GreaterOrEqualf(t, testCase.inputMaxReqs, httpClient.count, "should run %d times", testCase.inputMaxReqs)
|
|
|
|
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2022-03-22 13:12:11 +01:00
|
|
|
|
|
|
|
func TestAuth0_RotateAccessToken(t *testing.T) {
|
|
|
|
type test struct {
|
|
|
|
name string
|
|
|
|
inputResBody string
|
|
|
|
inputReqCode int
|
|
|
|
inputReqError error
|
|
|
|
inputMaxReqs int
|
|
|
|
inputInfo DeviceAuthInfo
|
|
|
|
inputAudience string
|
|
|
|
testingErrFunc require.ErrorAssertionFunc
|
|
|
|
expectedErrorMSG string
|
|
|
|
testingFunc require.ComparisonAssertionFunc
|
|
|
|
expectedOut TokenInfo
|
|
|
|
expectedMSG string
|
|
|
|
expectPayload TokenRequestPayload
|
|
|
|
}
|
|
|
|
|
|
|
|
defaultInfo := DeviceAuthInfo{
|
|
|
|
DeviceCode: "test",
|
|
|
|
ExpiresIn: 10,
|
|
|
|
Interval: 1,
|
|
|
|
}
|
|
|
|
|
|
|
|
tokenReqPayload := TokenRequestPayload{
|
|
|
|
GrantType: auth0RefreshGrant,
|
|
|
|
ClientID: "test",
|
|
|
|
RefreshToken: "refresh_test",
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase1 := test{
|
|
|
|
name: "Payload Is Valid",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputReqCode: 200,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase2 := test{
|
|
|
|
name: "Exit On Network Error",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
inputReqError: fmt.Errorf("error"),
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase3 := test{
|
|
|
|
name: "Exit On Non 200 Status Code",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputReqCode: 401,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
expectedErrorMSG: "should return error",
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
}
|
|
|
|
|
|
|
|
audience := "test"
|
|
|
|
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
|
|
|
|
var hmacSampleSecret []byte
|
|
|
|
tokenString, _ := token.SignedString(hmacSampleSecret)
|
|
|
|
|
|
|
|
testCase4 := test{
|
|
|
|
name: "Exit On Invalid Audience",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
|
|
inputReqCode: 200,
|
|
|
|
inputAudience: "super test",
|
|
|
|
testingErrFunc: require.Error,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
}
|
|
|
|
|
|
|
|
testCase5 := test{
|
|
|
|
name: "Received Token Info",
|
|
|
|
inputInfo: defaultInfo,
|
|
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
|
|
inputReqCode: 200,
|
|
|
|
inputAudience: audience,
|
|
|
|
testingErrFunc: require.NoError,
|
|
|
|
testingFunc: require.EqualValues,
|
|
|
|
expectPayload: tokenReqPayload,
|
|
|
|
expectedOut: TokenInfo{AccessToken: tokenString},
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
|
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
|
|
|
|
httpClient := mockHTTPClient{
|
|
|
|
resBody: testCase.inputResBody,
|
|
|
|
code: testCase.inputReqCode,
|
|
|
|
err: testCase.inputReqError,
|
|
|
|
MaxReqs: testCase.inputMaxReqs,
|
|
|
|
}
|
|
|
|
|
|
|
|
auth0 := Auth0{
|
|
|
|
Audience: testCase.inputAudience,
|
|
|
|
ClientID: testCase.expectPayload.ClientID,
|
|
|
|
Domain: "test.auth0.com",
|
|
|
|
HTTPClient: &httpClient,
|
|
|
|
}
|
|
|
|
|
|
|
|
tokenInfo, err := auth0.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken)
|
|
|
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
|
|
|
|
|
|
|
var payload []byte
|
|
|
|
var emptyPayload TokenRequestPayload
|
|
|
|
if testCase.expectPayload != emptyPayload {
|
|
|
|
payload, _ = json.Marshal(testCase.expectPayload)
|
|
|
|
}
|
|
|
|
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
|
|
|
|
|
|
|
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
|
|
|
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|