package internal import ( "context" "encoding/json" "fmt" "github.com/golang-jwt/jwt" "github.com/stretchr/testify/require" "io/ioutil" "net/http" "strings" "testing" "time" ) type mockHTTPClient struct { code int resBody string reqBody string MaxReqs int count int countResBody string err error } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { body, err := ioutil.ReadAll(req.Body) if err == nil { c.reqBody = string(body) } if c.MaxReqs > c.count { c.count++ return &http.Response{ StatusCode: c.code, Body: ioutil.NopCloser(strings.NewReader(c.countResBody)), }, c.err } return &http.Response{ StatusCode: c.code, Body: ioutil.NopCloser(strings.NewReader(c.resBody)), }, c.err } func TestHosted_RequestDeviceCode(t *testing.T) { type test struct { name string inputResBody string inputReqCode int inputReqError error testingErrFunc require.ErrorAssertionFunc expectedErrorMSG string testingFunc require.ComparisonAssertionFunc expectedOut DeviceAuthInfo expectedMSG string expectPayload RequestDeviceCodePayload } testCase1 := test{ name: "Payload Is Valid", expectPayload: RequestDeviceCodePayload{ Audience: "ok", ClientID: "bla", }, inputReqCode: 200, testingErrFunc: require.Error, testingFunc: require.EqualValues, } testCase2 := test{ name: "Exit On Network Error", inputReqError: fmt.Errorf("error"), testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } testCase3 := test{ name: "Exit On Exit Code", inputReqCode: 400, testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } testCase4Out := DeviceAuthInfo{ExpiresIn: 10} testCase4 := test{ name: "Got Device Code", inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn), expectPayload: RequestDeviceCodePayload{ Audience: "ok", ClientID: "bla", }, inputReqCode: 200, testingErrFunc: require.NoError, testingFunc: require.EqualValues, expectedOut: testCase4Out, expectedMSG: "out should match", } for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { t.Run(testCase.name, func(t *testing.T) { httpClient := mockHTTPClient{ resBody: testCase.inputResBody, code: testCase.inputReqCode, err: testCase.inputReqError, } hosted := Hosted{ Audience: testCase.expectPayload.Audience, ClientID: testCase.expectPayload.ClientID, Domain: "test.hosted.com", HTTPClient: &httpClient, } authInfo, err := hosted.RequestDeviceCode(context.TODO()) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) payload, _ := json.Marshal(testCase.expectPayload) require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match") testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG) }) } } func TestHosted_WaitToken(t *testing.T) { type test struct { name string inputResBody string inputReqCode int inputReqError error inputMaxReqs int inputCountResBody string inputTimeout time.Duration inputInfo DeviceAuthInfo inputAudience string testingErrFunc require.ErrorAssertionFunc expectedErrorMSG string testingFunc require.ComparisonAssertionFunc expectedOut TokenInfo expectedMSG string expectPayload TokenRequestPayload } defaultInfo := DeviceAuthInfo{ DeviceCode: "test", ExpiresIn: 10, Interval: 1, } tokenReqPayload := TokenRequestPayload{ GrantType: HostedGrantType, DeviceCode: defaultInfo.DeviceCode, ClientID: "test", } testCase1 := test{ name: "Payload Is Valid", inputInfo: defaultInfo, inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, inputReqCode: 200, testingErrFunc: require.Error, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, } testCase2 := test{ name: "Exit On Network Error", inputInfo: defaultInfo, inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, expectPayload: tokenReqPayload, inputReqError: fmt.Errorf("error"), testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } testCase3 := test{ name: "Exit On 4XX When Not Pending", inputInfo: defaultInfo, inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, inputReqCode: 400, expectPayload: tokenReqPayload, testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } testCase4 := test{ name: "Exit On Exit Code 5XX", inputInfo: defaultInfo, inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, inputReqCode: 500, expectPayload: tokenReqPayload, testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } testCase5 := test{ name: "Exit On Content Timeout", inputInfo: defaultInfo, inputTimeout: 0 * time.Second, testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } audience := "test" token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience}) var hmacSampleSecret []byte tokenString, _ := token.SignedString(hmacSampleSecret) testCase6 := test{ name: "Exit On Invalid Audience", inputInfo: defaultInfo, inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, inputReqCode: 200, inputAudience: "super test", testingErrFunc: require.Error, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, } testCase7 := test{ name: "Received Token Info", inputInfo: defaultInfo, inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, inputReqCode: 200, inputAudience: audience, testingErrFunc: require.NoError, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, expectedOut: TokenInfo{AccessToken: tokenString}, } testCase8 := test{ name: "Received Token Info after Multiple tries", inputInfo: defaultInfo, inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, inputMaxReqs: 2, inputCountResBody: "{\"error\":\"authorization_pending\"}", inputReqCode: 200, inputAudience: audience, testingErrFunc: require.NoError, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, expectedOut: TokenInfo{AccessToken: tokenString}, } for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7, testCase8} { t.Run(testCase.name, func(t *testing.T) { httpClient := mockHTTPClient{ resBody: testCase.inputResBody, code: testCase.inputReqCode, err: testCase.inputReqError, MaxReqs: testCase.inputMaxReqs, countResBody: testCase.inputCountResBody, } hosted := Hosted{ Audience: testCase.inputAudience, ClientID: testCase.expectPayload.ClientID, Domain: "test.hosted.com", HTTPClient: &httpClient, } ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) defer cancel() tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) var payload []byte var emptyPayload TokenRequestPayload if testCase.expectPayload != emptyPayload { payload, _ = json.Marshal(testCase.expectPayload) } require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match") testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG) require.GreaterOrEqualf(t, testCase.inputMaxReqs, httpClient.count, "should run %d times", testCase.inputMaxReqs) }) } } func TestHosted_RotateAccessToken(t *testing.T) { type test struct { name string inputResBody string inputReqCode int inputReqError error inputMaxReqs int inputInfo DeviceAuthInfo inputAudience string testingErrFunc require.ErrorAssertionFunc expectedErrorMSG string testingFunc require.ComparisonAssertionFunc expectedOut TokenInfo expectedMSG string expectPayload TokenRequestPayload } defaultInfo := DeviceAuthInfo{ DeviceCode: "test", ExpiresIn: 10, Interval: 1, } tokenReqPayload := TokenRequestPayload{ GrantType: HostedRefreshGrant, ClientID: "test", RefreshToken: "refresh_test", } testCase1 := test{ name: "Payload Is Valid", inputInfo: defaultInfo, inputReqCode: 200, testingErrFunc: require.Error, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, } testCase2 := test{ name: "Exit On Network Error", inputInfo: defaultInfo, expectPayload: tokenReqPayload, inputReqError: fmt.Errorf("error"), testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } testCase3 := test{ name: "Exit On Non 200 Status Code", inputInfo: defaultInfo, inputReqCode: 401, expectPayload: tokenReqPayload, testingErrFunc: require.Error, expectedErrorMSG: "should return error", testingFunc: require.EqualValues, } audience := "test" token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience}) var hmacSampleSecret []byte tokenString, _ := token.SignedString(hmacSampleSecret) testCase4 := test{ name: "Exit On Invalid Audience", inputInfo: defaultInfo, inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), inputReqCode: 200, inputAudience: "super test", testingErrFunc: require.Error, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, } testCase5 := test{ name: "Received Token Info", inputInfo: defaultInfo, inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), inputReqCode: 200, inputAudience: audience, testingErrFunc: require.NoError, testingFunc: require.EqualValues, expectPayload: tokenReqPayload, expectedOut: TokenInfo{AccessToken: tokenString}, } for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { t.Run(testCase.name, func(t *testing.T) { httpClient := mockHTTPClient{ resBody: testCase.inputResBody, code: testCase.inputReqCode, err: testCase.inputReqError, MaxReqs: testCase.inputMaxReqs, } hosted := Hosted{ Audience: testCase.inputAudience, ClientID: testCase.expectPayload.ClientID, Domain: "test.hosted.com", HTTPClient: &httpClient, } tokenInfo, err := hosted.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) var payload []byte var emptyPayload TokenRequestPayload if testCase.expectPayload != emptyPayload { payload, _ = json.Marshal(testCase.expectPayload) } require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match") testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG) }) } }